diff --git a/azure-keyvault/azure/keyvault/__init__.py b/azure-keyvault/azure/keyvault/__init__.py index 04c2e04ae9ee..1b62621a32b1 100644 --- a/azure-keyvault/azure/keyvault/__init__.py +++ b/azure-keyvault/azure/keyvault/__init__.py @@ -11,6 +11,7 @@ from .custom import http_bearer_challenge_cache as HttpBearerChallengeCache from .custom.http_bearer_challenge import HttpBearerChallenge +from .custom.http_challenge import HttpChallenge from .custom.key_vault_client import CustomKeyVaultClient as KeyVaultClient from .custom.key_vault_id import (KeyVaultId, KeyId, @@ -20,7 +21,8 @@ CertificateOperationId, StorageAccountId, StorageSasDefinitionId) -from .custom.key_vault_authentication import KeyVaultAuthentication, KeyVaultAuthBase +from .custom.key_vault_authentication import KeyVaultAuthentication, KeyVaultAuthBase, AccessToken +from .custom.http_message_security import generate_pop_key from .version import VERSION __all__ = ['KeyVaultClient', @@ -34,8 +36,11 @@ 'StorageSasDefinitionId', 'HttpBearerChallengeCache', 'HttpBearerChallenge', + 'HttpChallenge', 'KeyVaultAuthentication', - 'KeyVaultAuthBase'] + 'KeyVaultAuthBase', + 'generate_pop_key', + 'AccessToken'] __version__ = VERSION diff --git a/azure-keyvault/azure/keyvault/custom/http_challenge.py b/azure-keyvault/azure/keyvault/custom/http_challenge.py new file mode 100644 index 000000000000..e6d75f15f158 --- /dev/null +++ b/azure-keyvault/azure/keyvault/custom/http_challenge.py @@ -0,0 +1,115 @@ +#--------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +#--------------------------------------------------------------------------------------------- + +try: + import urllib.parse as parse +except ImportError: + import urlparse as parse # pylint: disable=import-error + + +class HttpChallenge(object): + + def __init__(self, request_uri, challenge, response_headers=None): + """ Parses an HTTP WWW-Authentication Bearer challenge from a server. """ + self.source_authority = self._validate_request_uri(request_uri) + self.source_uri = request_uri + self._parameters = {} + + # get the scheme of the challenge and remove from the challenge string + trimmed_challenge = self._validate_challenge(challenge) + split_challenge = trimmed_challenge.split(' ', 1) + self.scheme = split_challenge[0] + trimmed_challenge = split_challenge[1] + + # split trimmed challenge into comma-separated name=value pairs. Values are expected + # to be surrounded by quotes which are stripped here. + for item in trimmed_challenge.split(','): + # process name=value pairs + comps = item.split('=') + if len(comps) == 2: + key = comps[0].strip(' "') + value = comps[1].strip(' "') + if key: + self._parameters[key] = value + + # minimum set of parameters + if not self._parameters: + raise ValueError('Invalid challenge parameters') + + # must specify authorization or authorization_uri + if 'authorization' not in self._parameters and 'authorization_uri' not in self._parameters: + raise ValueError('Invalid challenge parameters') + + # if the response headers were supplied + if response_headers: + # get the message signing key and message key encryption key from the headers + self.server_signature_key = response_headers.get('x-ms-message-signing-key', None) + self.server_encryption_key = response_headers.get('x-ms-message-encryption-key', None) + + def is_bearer_challenge(self): + """ Tests whether the HttpChallenge a Bearer challenge. + rtype: bool """ + if not self.scheme: + return False + + return self.scheme.lower() == 'bearer' + + def is_pop_challenge(self): + """ Tests whether the HttpChallenge is a proof of possession challenge. + rtype: bool """ + if not self.scheme: + return False + + return self.scheme.lower() == 'pop' + + def get_value(self, key): + return self._parameters.get(key) + + def get_authorization_server(self): + """ Returns the URI for the authorization server if present, otherwise empty string. """ + value = '' + for key in ['authorization_uri', 'authorization']: + value = self.get_value(key) or '' + if value: + break + return value + + def get_resource(self): + """ Returns the resource if present, otherwise empty string. """ + return self.get_value('resource') or '' + + def get_scope(self): + """ Returns the scope if present, otherwise empty string. """ + return self.get_value('scope') or '' + + def supports_pop(self): + """ Returns True if challenge supports pop token auth else False """ + return self._parameters.get('supportspop', '').lower() == 'true' + + def supports_message_protection(self): + """ Returns True if challenge vault supports message protection """ + return self.supports_pop() and self.server_encryption_key and self.server_signature_key + + def _validate_challenge(self, challenge): + """ Verifies that the challenge is a valid auth challenge and returns the key=value pairs. """ + if not challenge: + raise ValueError('Challenge cannot be empty') + + return challenge.strip() + + # pylint: disable=no-self-use + def _validate_request_uri(self, uri): + """ Extracts the host authority from the given URI. """ + if not uri: + raise ValueError('request_uri cannot be empty') + + uri = parse.urlparse(uri) + if not uri.netloc: + raise ValueError('request_uri must be an absolute URI') + + if uri.scheme.lower() not in ['http', 'https']: + raise ValueError('request_uri must be HTTP or HTTPS') + + return uri.netloc diff --git a/azure-keyvault/azure/keyvault/custom/http_message_security.py b/azure-keyvault/azure/keyvault/custom/http_message_security.py new file mode 100644 index 000000000000..b2f318a25b3b --- /dev/null +++ b/azure-keyvault/azure/keyvault/custom/http_message_security.py @@ -0,0 +1,192 @@ +#--------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +#--------------------------------------------------------------------------------------------- + +import json +import time +import os +from .internal import _a128cbc_hs256_encrypt, _a128cbc_hs256_decrypt, _JwsHeader, _JwsObject, \ + _JweHeader, _JweObject, _str_to_b64url, _bstr_to_b64url, _b64_to_bstr, _RsaKey + + +def generate_pop_key(): + """ + Generates a key which can be used for Proof Of Possession token authentication. + :return: + """ + return _RsaKey.generate() + + +class HttpMessageSecurity(object): + """ + Used for message authorization, encryption and decrtyption. + + This class is intended for internal use only. Details are subject to non-compatible changes, consumers of the + azure-keyvault module should not take dependencies on this class or its current implementation. + """ + def __init__(self, client_security_token=None, + client_signature_key=None, + client_encryption_key=None, + server_signature_key=None, + server_encryption_key=None): + self.client_security_token = client_security_token + self.client_signature_key = client_signature_key + self.client_encryption_key = client_encryption_key + self.server_signature_key = server_signature_key + self.server_encryption_key = server_encryption_key + + def protect_request(self, request): + """ + Adds authorization header, and encrypts and signs the request if supported on the specific request. + :param request: unprotected request to apply security protocol + :return: protected request with appropriate security protocal applied + """ + # Setup the auth header on the request + # Due to limitations in the service we hard code the auth scheme to 'Bearer' as the service will fail with any + # other scheme or a different casing such as 'bearer', once this is fixed the following line should be replaced: + # request.headers['Authorization'] = '{} {}'.format(auth[0], auth[1]) + request.headers['Authorization'] = '{} {}'.format('Bearer', self.client_security_token) + + # if the current message security doesn't support message protection, or the body is empty + # skip protection and return the original request + if not self.supports_protection() or len(request.body) == 0: + return request + + plain_text = request.body + + # if the client encryption key is specified add it to the body of the request + if self.client_encryption_key: + # note that this assumes that the body is already json and not simple string content + # this is true for all requests which currently support message encryption, but might + # need to be revisited when the types of + body_dict = json.loads(plain_text) + body_dict['rek'] = {'jwk': self.client_encryption_key.to_jwk().serialize()} + plain_text = json.dumps(body_dict).encode(encoding='utf8') + + # build the header for the jws body + jws_header = _JwsHeader() + jws_header.alg = 'RS256' + jws_header.kid = self.client_signature_key.kid + jws_header.at = self.client_security_token + jws_header.ts = int(time.time()) + jws_header.typ = 'PoP' + + jws = _JwsObject() + + jws.protected = jws_header.to_compact_header() + jws.payload = self._protect_payload(plain_text) + data = (jws.protected + '.' + jws.payload).encode('ascii') + jws.signature = _bstr_to_b64url(self.client_signature_key.sign(data)) + + request.headers['Content-Type'] = 'application/jose+json' + + request.prepare_body(data=jws.to_flattened_jws(), files=None) + + return request + + def unprotect_response(self, response, **kwargs): + """ + Removes protection from the specified response + :param request: response from the key vault service + :return: unprotected response with any security protocal encryption removed + """ + body = response.content + # if the current message security doesn't support message protection, the body is empty, or the request failed + # skip protection and return the original response + if not self.supports_protection() or len(response.content) == 0 or response.status_code != 200: + return response + + # ensure the content-type is application/jose+json + if 'application/jose+json' not in response.headers.get('content-type', '').lower(): + raise ValueError('Invalid protected response') + + # deserialize the response into a JwsObject, using response.text so requests handles the encoding + jws = _JwsObject().deserialize(body) + + # deserialize the protected header + jws_header = _JwsHeader.from_compact_header(jws.protected) + + # ensure the jws signature kid matches the key from original challenge + # and the alg matches expected signature alg + if jws_header.kid != self.server_signature_key.kid \ + or jws_header.alg != 'RS256': + raise ValueError('Invalid protected response') + + # validate the signature of the jws + data = (jws.protected + '.' + jws.payload).encode('ascii') + # verify will raise an InvalidSignature exception if the signature doesn't match + self.server_signature_key.verify(signature=_b64_to_bstr(jws.signature), data=data) + + # get the unprotected response body + decrypted = self._unprotect_payload(jws.payload) + + response._content = decrypted + response.headers['Content-Type'] = 'application/json' + + return response + + def supports_protection(self): + """ + Determines if the the current HttpMessageSecurity object supports the message protection protocol. + :return: True if the current object supports protection, otherwise False + """ + return self.client_signature_key \ + and self.client_encryption_key \ + and self.server_signature_key \ + and self.server_encryption_key + + def _protect_payload(self, plaintext): + # create the jwe header for the payload + kek = self.server_encryption_key + jwe_header = _JweHeader() + jwe_header.alg = 'RSA-OAEP' + jwe_header.kid = kek.kid + jwe_header.enc = 'A128CBC-HS256' + + # create the jwe object + jwe = _JweObject() + jwe.protected = jwe_header.to_compact_header() + + # generate the content encryption key and iv + cek = os.urandom(32) + iv = os.urandom(16) + jwe.iv = _bstr_to_b64url(iv) + # wrap the cek using the server encryption key + wrapped = _bstr_to_b64url(kek.encrypt(cek)) + jwe.encrypted_key = wrapped + + # encrypt the plaintext body with the cek using the protected header + # as the authdata to get the ciphertext and the authtag + ciphertext, tag = _a128cbc_hs256_encrypt(cek, iv, plaintext, jwe.protected.encode('ascii')) + + jwe.ciphertext = _bstr_to_b64url(ciphertext) + jwe.tag = _bstr_to_b64url(tag) + + # flatten and encode the jwe for the final jws payload content + flat = jwe.to_flattened_jwe() + return _str_to_b64url(flat) + + def _unprotect_payload(self, payload): + # deserialize the payload + jwe = _JweObject().deserialize_b64(payload) + + # deserialize the payload header + jwe_header = _JweHeader.from_compact_header(jwe.protected) + + # ensure the kid matches the specified client encryption key + # and the key wrap alg and the data encryption enc match the expected + if self.client_encryption_key.kid != jwe_header.kid \ + or jwe_header.alg != 'RSA-OAEP' \ + or jwe_header.enc != 'A128CBC-HS256': + raise ValueError('Invalid protected response') + + # unwrap the cek using the client encryption key + cek = self.client_encryption_key.decrypt(_b64_to_bstr(jwe.encrypted_key)) + + # decrypt the cipher text to get the unprotected body content + return _a128cbc_hs256_decrypt(cek, + _b64_to_bstr(jwe.iv), + _b64_to_bstr(jwe.ciphertext), + jwe.protected.encode('ascii'), + _b64_to_bstr(jwe.tag)) \ No newline at end of file diff --git a/azure-keyvault/azure/keyvault/custom/internal.py b/azure-keyvault/azure/keyvault/custom/internal.py new file mode 100644 index 000000000000..f3992e8bde57 --- /dev/null +++ b/azure-keyvault/azure/keyvault/custom/internal.py @@ -0,0 +1,410 @@ +#--------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +#--------------------------------------------------------------------------------------------- + +import json +import uuid +import codecs +from base64 import b64encode, b64decode +import cryptography +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateNumbers, RSAPublicNumbers, \ + generate_private_key, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp, RSAPrivateKey, RSAPublicKey +from cryptography.hazmat.primitives.asymmetric import padding as asym_padding +from cryptography.hazmat.primitives import hashes, constant_time, padding, hmac + +from ..models import JsonWebKey + +def _a128cbc_hs256_encrypt(key, iv, plaintext, authdata): + if not key or not len(key) >= 32: + raise ValueError('key must be at least 256 bits for algorithm "A128CBC-HS256"') + if not iv or len(iv) != 16: + raise ValueError('iv must be 128 bits for algorithm "A128CBC-HS256"') + if not plaintext: + raise ValueError('plaintext must be specified') + if not authdata: + raise ValueError('authdata must be specified') + + # get the hmac key and the aes key from the specified key + hmac_key = key[:16] + aes_key = key[16:32] + + # calculate the length of authdata and store as bytes + auth_data_length = _int_to_bigendian_8_bytes(len(authdata) * 8) + + # pad the plaintext with pkcs7 + padder = padding.PKCS7(128).padder() + plaintext = padder.update(plaintext) + padder.finalize() + + # create the cipher and encrypt the plaintext + cipher = Cipher(algorithms.AES(aes_key), modes.CBC(iv), backend=default_backend()) + encryptor = cipher.encryptor() + ciphertext = encryptor.update(plaintext) + encryptor.finalize() + + # get the data to hash with HMAC, hash the data and take the first 16 bytes + hashdata = authdata + iv + ciphertext + auth_data_length + hmac_hash = hmac.HMAC(hmac_key, hashes.SHA256(), backend=default_backend()) + hmac_hash.update(hashdata) + tag = hmac_hash.finalize()[:16] + + return ciphertext, tag + + +def _a128cbc_hs256_decrypt(key, iv, ciphertext, authdata, authtag): + if not key or not len(key) >= 32: + raise ValueError('key must be at least 256 bits for algorithm "A128CBC-HS256"') + if not iv or len(iv) != 16: + raise ValueError('iv must be 128 bits for algorithm "A128CBC-HS256"') + if not ciphertext: + raise ValueError('ciphertext must be specified') + if not authdata: + raise ValueError('authdata must be specified') + if not authtag or len(authtag) != 16: + raise ValueError('authtag must be be 128 bits for algorithm "A128CBC-HS256"') + + hmac_key = key[:16] + aes_key = key[16:32] + auth_data_length = _int_to_bigendian_8_bytes(len(authdata) * 8) + + # ensure the authtag is the expected length for SHA256 hash + if not len(authtag) == 16: + raise ValueError('invalid tag') + + hashdata = authdata + iv + ciphertext + auth_data_length + hmac_hash = hmac.HMAC(hmac_key, hashes.SHA256(), backend=default_backend()) + hmac_hash.update(hashdata) + tag = hmac_hash.finalize()[:16] + + if not constant_time.bytes_eq(tag, authtag): + raise ValueError('"ciphertext" is not authentic') + + cipher = Cipher(algorithms.AES(aes_key), modes.CBC(iv), backend=default_backend()) + decryptor = cipher.decryptor() + plaintext = decryptor.update(ciphertext) + decryptor.finalize() + + # unpad the decrypted plaintext + padder = padding.PKCS7(128).unpadder() + plaintext = padder.update(plaintext) + padder.finalize() + + return plaintext + + +def _bytes_to_int(b): + if not b or not isinstance(b, bytes): + raise ValueError('b must be non-empty byte string') + + return int(codecs.encode(b, 'hex'), 16) + + +def _int_to_bytes(i): + h = hex(i) + if len(h) > 1 and h[0:2] == '0x': + h = h[2:] + + # need to strip L in python 2.x + h = h.strip('L') + + if len(h) % 2: + h = '0' + h + return codecs.decode(h, 'hex') + + +def _bstr_to_b64url(bstr, **kwargs): + """Serialize bytes into base-64 string. + :param str: Object to be serialized. + :rtype: str + """ + encoded = b64encode(bstr).decode() + return encoded.strip('=').replace('+', '-').replace('/', '_') + + +def _str_to_b64url(s, **kwargs): + """Serialize str into base-64 string. + :param str: Object to be serialized. + :rtype: str + """ + return _bstr_to_b64url(s.encode(encoding='utf8')) + + +def _b64_to_bstr(b64str): + """Deserialize base64 encoded string into string. + :param str b64str: response string to be deserialized. + :rtype: bytearray + :raises: TypeError if string format invalid. + """ + padding = '=' * (3 - (len(b64str) + 3) % 4) + b64str = b64str + padding + encoded = b64str.replace('-', '+').replace('_', '/') + return b64decode(encoded) + + +def _b64_to_str(b64str): + """Deserialize base64 encoded string into string. + :param str b64str: response string to be deserialized. + :rtype: str + :raises: TypeError if string format invalid. + """ + return _b64_to_bstr(b64str).decode('utf8') + + +def _int_to_bigendian_8_bytes(i): + b = _int_to_bytes(i) + + if len(b) > 8: + raise ValueError('the specified integer is to large to be represented by 8 bytes') + + if len(b) < 8: + b = (b'\0' * (8 - len(b))) + b + + return b + + +class _JoseObject(object): + + def deserialize(self, s): + d = json.loads(s) + self.__dict__ = d + return self + + def deserialize_b64(self, s): + self.deserialize(_b64_to_str(s)) + return self + + def serialize(self): + return json.dumps(self.__dict__) + + def serialize_b64url(self): + return _str_to_b64url(self.serialize()) + + +class _JoseHeader(_JoseObject): + + def to_compact_header(self): + return _str_to_b64url(json.dumps(self.__dict__)) + + +class _JweHeader(_JoseHeader): + def __init__(self, alg=None, kid=None, enc=None): + self.alg = alg + self.kid = kid + self.enc = enc + + @staticmethod + def from_compact_header(compact): + header = _JweHeader() + header.__dict__ = json.loads(_b64_to_str(compact)) + return header + + +class _JweObject(_JoseObject): + def __init__(self): + self.protected = None + self.encrypted_key = None + self.iv = None + self.ciphertext = None + self.tag = None + + def to_flattened_jwe(self): + if not (self.protected, self.encrypted_key, self.iv, self.ciphertext, self.tag): + raise ValueError('JWE is not complete.') + + return json.dumps(self.__dict__) + + +class _JwsHeader(_JoseHeader): + def __init__(self): + self.alg = None + self.kid = None + self.at = None + self.ts = None + self.p = None + self.typ = None + + @staticmethod + def from_compact_header(compact): + header = _JwsHeader() + header.__dict__ = json.loads(_b64_to_str(compact)) + return header + + +class _JwsObject(_JoseObject): + def __init__(self): + self.protected = None + self.payload = None + self.signature = None + + def to_flattened_jws(self): + if not (self.protected, self.payload, self.signature): + raise ValueError('JWS is not complete.') + + return json.dumps(self.__dict__) + + + +def _default_encryption_padding(): + return asym_padding.OAEP(mgf=asym_padding.MGF1(algorithm=hashes.SHA1()), algorithm=hashes.SHA1(), label=None) + + +def _default_signature_padding(): + return asym_padding.PKCS1v15() + + +def _default_signature_algorithm(): + return hashes.SHA256() + + +class _RsaKey(object): + PUBLIC_KEY_DEFAULT_OPS = ['encrypt', 'wrapKey', 'verify'] + PRIVATE_KEY_DEFAULT_OPS = ['encrypt', 'decrypt', 'wrapKey', 'unwrapKey', 'verify', 'sign'] + + def __init__(self): + self.kid = None + self.kty = None + self.key_ops = None + self._rsa_impl = None + + @property + def n(self): + return _int_to_bytes(self._public_key_material().n) + + @property + def e(self): + return _int_to_bytes(self._public_key_material().e) + + @property + def q(self): + return _int_to_bytes(self._private_key_material().q) if self.is_private_key() else None + + @property + def p(self): + return _int_to_bytes(self._private_key_material().p) if self.is_private_key() else None + + @property + def d(self): + return _int_to_bytes(self._private_key_material().d) if self.is_private_key() else None + + @property + def dq(self): + return _int_to_bytes(self._private_key_material().dmq1) if self.is_private_key() else None + + @property + def dp(self): + return _int_to_bytes(self._private_key_material().dmp1) if self.is_private_key() else None + + @property + def qi(self): + return _int_to_bytes(self._private_key_material().iqmp) if self.is_private_key() else None + + @property + def private_key(self): + return self._rsa_impl if self.is_private_key() else None + + @property + def public_key(self): + return self._rsa_impl.public_key() if self.is_private_key() else self._rsa_impl + + @staticmethod + def generate(kid=None, kty='RSA', size=2048, e=65537): + key = _RsaKey() + key.kid = kid or str(uuid.uuid4()) + key.kty = kty + key.key_ops = _RsaKey.PRIVATE_KEY_DEFAULT_OPS + key._rsa_impl = generate_private_key(public_exponent=e, + key_size=size, + backend=cryptography.hazmat.backends.default_backend()) + return key + + @staticmethod + def from_jwk_str(s): + jwk_dict = json.loads(s) + jwk = JsonWebKey.from_dict(jwk_dict) + return _RsaKey.from_jwk(jwk) + + @staticmethod + def from_jwk(jwk): + if not isinstance(jwk, JsonWebKey): + raise TypeError('The specified jwk must be a JsonWebKey') + + if jwk.kty != 'RSA' and jwk.kty != 'RSA-HSM': + raise ValueError('The specified jwk must have a key type of "RSA" or "RSA-HSM"') + + if not jwk.n or not jwk.e: + raise ValueError('Invalid RSA jwk, both n and e must be have values') + + rsa_key = _RsaKey() + rsa_key.kid = jwk.kid + rsa_key.kty = jwk.kty + rsa_key.key_ops = jwk.key_ops + + pub = RSAPublicNumbers(n=_bytes_to_int(jwk.n), e=_bytes_to_int(jwk.e)) + + # if the private key values are specified construct a private key + # only the secret primes and private exponent are needed as other fields can be calculated + if jwk.p and jwk.q and jwk.d: + # convert the values of p, q, and d from bytes to int + p = _bytes_to_int(jwk.p) + q = _bytes_to_int(jwk.q) + d = _bytes_to_int(jwk.d) + + # convert or compute the remaining private key numbers + dmp1 = _bytes_to_int(jwk.dp) if jwk.dp else rsa_crt_dmp1(private_exponent=d, p=p) + dmq1 = _bytes_to_int(jwk.dq) if jwk.dq else rsa_crt_dmq1(private_exponent=d, q=q) + iqmp = _bytes_to_int(jwk.qi) if jwk.qi else rsa_crt_iqmp(p=p, q=q) + + # create the private key from the jwk key values + priv = RSAPrivateNumbers(p=p, q=q, d=d, dmp1=dmp1, dmq1=dmq1, iqmp=iqmp, public_numbers=pub) + key_impl = priv.private_key(cryptography.hazmat.backends.default_backend()) + + # if the necessary private key values are not specified create the public key + else: + key_impl = pub.public_key(cryptography.hazmat.backends.default_backend()) + + rsa_key._rsa_impl = key_impl + + return rsa_key + + def to_jwk(self, include_private=False): + jwk = JsonWebKey(kid=self.kid, + kty=self.kty, + key_ops=self.key_ops if include_private else _RsaKey.PUBLIC_KEY_DEFAULT_OPS, + n=self.n, + e=self.e) + + if include_private: + jwk.q = self.q + jwk.p = self.p + jwk.d = self.d + jwk.dq = self.dq + jwk.dp = self.dp + jwk.qi = self.qi + + return jwk + + def encrypt(self, plaintext, padding=_default_encryption_padding()): + return self.public_key.encrypt(plaintext, padding) + + def decrypt(self, ciphertext, padding=_default_encryption_padding()): + if not self.is_private_key(): + raise NotImplementedError('The current RsaKey does not support decrypt') + + return self.private_key.decrypt(ciphertext, padding) + + def sign(self, data, padding=_default_signature_padding(), algorithm=_default_signature_algorithm()): + if not self.is_private_key(): + raise NotImplementedError('The current RsaKey does not support sign') + + return self.private_key.sign(data, padding, algorithm) + + def verify(self, signature, data, padding=_default_signature_padding(), algorithm=_default_signature_algorithm()): + return self.public_key.verify(signature, data, padding, algorithm) + + def is_private_key(self): + return isinstance(self._rsa_impl, RSAPrivateKey) + + def _public_key_material(self): + return self.public_key.public_numbers() + + def _private_key_material(self): + return self.private_key.private_numbers() if self.private_key else None diff --git a/azure-keyvault/azure/keyvault/custom/key_vault_authentication.py b/azure-keyvault/azure/keyvault/custom/key_vault_authentication.py index faf7b16e052e..d990036eee4b 100644 --- a/azure-keyvault/azure/keyvault/custom/key_vault_authentication.py +++ b/azure-keyvault/azure/keyvault/custom/key_vault_authentication.py @@ -5,11 +5,30 @@ import threading import requests +import inspect +from collections import namedtuple from requests.auth import AuthBase from requests.cookies import extract_cookies_to_jar -from azure.keyvault import HttpBearerChallenge +from azure.keyvault import HttpChallenge from azure.keyvault import HttpBearerChallengeCache as ChallengeCache from msrest.authentication import OAuthTokenAuthentication +from .http_message_security import HttpMessageSecurity +from .internal import _RsaKey + + +AccessToken = namedtuple('AccessToken', ['scheme', 'token', 'key']) +AccessToken.__new__.__defaults__ = ('Bearer', None, None) + +_message_protection_supported_methods = ['sign', 'verify', 'encrypt', 'decrypt', 'wrapkey', 'unwrapkey'] + + +def _message_protection_supported(challenge, request): + # right now only specific key operations are supported so return true only + # if the vault supports message protection, the request is to the keys collection + # and the requested operation supports it + return challenge.supports_message_protection() \ + and '/keys/' in request.url \ + and request.url.split('?')[0].strip('/').split('/')[-1].lower() in _message_protection_supported_methods class KeyVaultAuthBase(AuthBase): @@ -22,15 +41,25 @@ def __init__(self, authorization_callback): Creates a new KeyVaultAuthBase instance used for handling authentication challenges, by hooking into the request AuthBase extension model. :param authorization_callback: A callback used to provide authentication credentials to the key vault data service. - This callback should take three str arguments: authorization uri, resource, and scope, and return - a tuple of (token type, access token). + This callback should take four str arguments: authorization uri, resource, scope, and scheme, and return + an AccessToken + return AccessToken(scheme=token['token_type'], token=token['access_token']) + Note: for backward compatibility a tuple of the scheme and token can also be returned. return token['token_type'], token['access_token'] """ - self._callback = authorization_callback + self._user_callback = authorization_callback + self._callback = self._auth_callback_compat self._token = None self._thread_local = threading.local() self._thread_local.pos = None self._thread_local.auth_attempted = False + self._thread_local.orig_body = None + + # for backwards compatibility we need to support callbacks which don't accept the scheme + def _auth_callback_compat(self, server, resource, scope, scheme): + return self._user_callback(server, resource, scope) \ + if len(inspect.getargspec(self._user_callback).args) == 3 \ + else self._user_callback(server, resource, scope, scheme) def __call__(self, request): """ @@ -43,28 +72,30 @@ def __call__(self, request): if self._callback: challenge = ChallengeCache.get_challenge_for_url(request.url) if challenge: - # if challenge cached, use the authorization_callback to retrieve token and update the request - self.set_authorization_header(request, challenge) + # if challenge cached get the message security + security = self._get_message_security(request, challenge) + # protect the request + security.protect_request(request) + # register a response hook to unprotect the response + request.register_hook('response', security.unprotect_response) else: - # if the challenge is not cached we will let the request proceed without the auth header so we - # get back the proper challenge in response. We register a callback to handle the response 401 response. - try: - self._thread_local.pos = request.body.tell() - except AttributeError: - self._thread_local.pos = None - + # if the challenge is not cached we will strip the body and proceed without the auth header so we + # get back the auth challenge for the request + self._thread_local.orig_body = request.body + request.body = '' + request.headers['Content-Length'] = 0 + request.register_hook('response', self._handle_401) + request.register_hook('response', self._handle_redirect) self._thread_local.auth_attempted = False - request.register_hook('response', self.handle_401) - request.register_hook('response', self.handle_redirect) return request - def handle_redirect(self, r, **kwargs): + def _handle_redirect(self, r, **kwargs): """Reset auth_attempted on redirects.""" if r.is_redirect: self._thread_local.auth_attempted = False - def handle_401(self, response, **kwargs): + def _handle_401(self, response, **kwargs): """ Takes the response authenticates and resends if neccissary :return: The final response to the authenticated request @@ -75,28 +106,26 @@ def handle_401(self, response, **kwargs): self._thread_local.auth_attempted = False return response - auth_header = response.headers.get('www-authenticate', '') - - # if the response auth header is not a bearer challenge do not auth and return response - if not HttpBearerChallenge.is_bearer_challenge(auth_header): - self._thread_local.auth_attempted = False - return response - # If we've already attempted to auth for this request once, do not auth and return response if self._thread_local.auth_attempted: self._thread_local.auth_attempted = False return response + auth_header = response.headers.get('www-authenticate', '') + # Otherwise authenticate and retry the request self._thread_local.auth_attempted = True - if self._thread_local.pos is not None: - # Rewind the file position indicator of the body to where - # it was to resend the request. - response.request.body.seek(self._thread_local.pos) + # parse the challenge + challenge = HttpChallenge(response.request.url, auth_header, response.headers) + + # bearer and PoP are the only authentication schemes supported at this time + # if the response auth header is not a bearer challenge or pop challange do not auth and return response + if not (challenge.is_bearer_challenge() or challenge.is_pop_challenge()): + self._thread_local.auth_attempted = False + return response # add the challenge to the cache - challenge = HttpBearerChallenge(response.request.url, auth_header) ChallengeCache.set_challenge_for_url(response.request.url, challenge) # Consume content and release the original connection @@ -106,28 +135,51 @@ def handle_401(self, response, **kwargs): # copy the request to resend prep = response.request.copy() + + if self._thread_local.orig_body is not None: + # replace the body with the saved body + prep.prepare_body(data=self._thread_local.orig_body, files=None) + extract_cookies_to_jar(prep._cookies, response.request, response.raw) prep.prepare_cookies(prep._cookies) - # setup the auth header on the copied request - self.set_authorization_header(prep, challenge) + security = self._get_message_security(prep, challenge) - # resend the request with proper authentication + # auth and protect the prepped request message + security.protect_request(prep) + + # resend the request with proper authentication and message protection _response = response.connection.send(prep, **kwargs) _response.history.append(response) _response.request = prep + + # unprotected the response + security.unprotect_response(_response) + return _response - def set_authorization_header(self, request, challenge): - auth = self._callback( - challenge.get_authorization_server(), - challenge.get_resource(), - challenge.get_scope()) + def _get_message_security(self, request, challenge): + scheme = challenge.scheme + + # if the given request can be protected ensure the scheme is PoP so the proper access token is requested + if _message_protection_supported(challenge, request): + scheme = 'PoP' - # Due to limitations in the service we hard code the auth scheme to 'Bearer' as the service will fail with any other - # scheme or a different casing such as 'bearer', once this is fixed the following line should be replace with: - # request.headers['Authorization'] = '{} {}'.format(auth[0], auth[1]) - request.headers['Authorization'] = '{} {}'.format('Bearer', auth[1]) + # use the authentication_callback to get the token and create the message security + token = AccessToken(*self._callback(challenge.get_authorization_server(), + challenge.get_resource(), + challenge.get_scope(), + scheme)) + security = HttpMessageSecurity(client_security_token=token.token) + + # if the given request can be protected add the appropriate keys to the message security + if scheme == 'PoP': + security.client_signature_key = token.key + security.client_encryption_key = _RsaKey.generate() + security.server_encryption_key = _RsaKey.from_jwk_str(challenge.server_encryption_key) + security.server_signature_key = _RsaKey.from_jwk_str(challenge.server_signature_key) + + return security class KeyVaultAuthentication(OAuthTokenAuthentication): @@ -163,12 +215,12 @@ def __init__(self, authorization_callback=None, credentials=None): self._credentials = credentials if not authorization_callback: - def auth_callback(server, resource, scope): + def auth_callback(server, resource, scope, scheme): if self._credentials.resource != resource: self._credentials.resource = resource self._credentials.set_token() token = self._credentials.token - return token['token_type'], token['access_token'] + return AccessToken(scheme=token['token_type'], token=token['access_token'], key=None) authorization_callback = auth_callback diff --git a/azure-keyvault/azure/keyvault/version.py b/azure-keyvault/azure/keyvault/version.py index 3b613fee836a..7cdcf7599b55 100755 --- a/azure-keyvault/azure/keyvault/version.py +++ b/azure-keyvault/azure/keyvault/version.py @@ -9,5 +9,5 @@ # regenerated. # -------------------------------------------------------------------------- -VERSION = "0.3.7" +VERSION = "1.0.0a1" diff --git a/azure-keyvault/setup.py b/azure-keyvault/setup.py index c0f336b80b11..fe80673dae8b 100644 --- a/azure-keyvault/setup.py +++ b/azure-keyvault/setup.py @@ -81,6 +81,8 @@ 'msrestazure>=0.4.15', 'msrest>=0.4.17' 'azure-common~=1.1.5', + 'cryptography>=2.1.4', + 'requests>=2.18.4' ], cmdclass=cmdclass ) diff --git a/azure-keyvault/tests/test_internal.py b/azure-keyvault/tests/test_internal.py new file mode 100644 index 000000000000..6ac5ccfb48dd --- /dev/null +++ b/azure-keyvault/tests/test_internal.py @@ -0,0 +1,281 @@ +import unittest +import os +import random +import string +import json +import uuid +import time +from azure.keyvault.custom.internal import _bytes_to_int, _int_to_bytes, _int_to_bigendian_8_bytes, \ + _bstr_to_b64url, _b64_to_bstr, _b64_to_str, _str_to_b64url, _a128cbc_hs256_decrypt, _a128cbc_hs256_encrypt, \ + _RsaKey, _JwsHeader, _JweHeader, _JwsObject, _JweObject + + +class EncodingTests(unittest.TestCase): + def test_int_byte_conversion(self): + # generate a random byte + b = os.urandom(1) + i = _bytes_to_int(b) + self._assert_bytes_significantly_equal(b, _int_to_bytes(i)) + + # generate a random number of random bytes + b = os.urandom(random.randint(1, 32)) + i = _bytes_to_int(b) + self._assert_bytes_significantly_equal(b, _int_to_bytes(i)) + + # generate random 4096 bits (4k key) + b = os.urandom(512) + i = _bytes_to_int(b) + self._assert_bytes_significantly_equal(b, _int_to_bytes(i)) + + # + b = b'\x00\x00\x00\x01' + i = _bytes_to_int(b) + self._assert_bytes_significantly_equal(b, _int_to_bytes(i)) + + b = b'' + with self.assertRaises(ValueError): + _bytes_to_int(b) + + b = None + with self.assertRaises(ValueError): + _bytes_to_int(b) + + def test_int_to_bigendian_8_bytes(self): + i = 0xFFFFFFFFFFFFFFFF + b = _int_to_bigendian_8_bytes(i) + self.assertEqual(b, b'\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF') + + i = 0 + b = _int_to_bigendian_8_bytes(i) + self.assertEqual(b, b'\x00\x00\x00\x00\x00\x00\x00\x00') + + i = random.randint(1, 0xFFFFFFFFFFFFFFFF) + b = _int_to_bigendian_8_bytes(i) + self.assertEqual(len(b), 8) + self.assertEqual(i, _bytes_to_int(b)) + + i = random.randint(0xFFFFFFFFFFFFFFFF01, 0xFFFFFFFFFFFFFFFFFF) + with self.assertRaises(ValueError): + _int_to_bigendian_8_bytes(i) + + def test_bstr_encode_decode(self): + b = b'' + b64 = _bstr_to_b64url(b) + self.assertEqual(b, _b64_to_bstr(b64)) + + b = os.urandom(1) + b64 = _bstr_to_b64url(b) + self.assertEqual(b, _b64_to_bstr(b64)) + + b = os.urandom(random.randint(2, 32)) + b64 = _bstr_to_b64url(b) + self.assertEqual(b, _b64_to_bstr(b64)) + + b = os.urandom(512) + b64 = _bstr_to_b64url(b) + self.assertEqual(b, _b64_to_bstr(b64)) + + def test_str_encode_decode(self): + s = '' + b64 = _str_to_b64url(s) + self.assertEqual(s, _b64_to_str(b64)) + + s = self._random_str(1) + b64 = _str_to_b64url(s) + self.assertEqual(s, _b64_to_str(b64)) + + s = self._random_str(random.randint(2, 32)) + b64 = _str_to_b64url(s) + self.assertEqual(s, _b64_to_str(b64)) + + s = self._random_str(4096) + b64 = _str_to_b64url(s) + self.assertEqual(s, _b64_to_str(b64)) + + def test_a128cbc_hs256_encrypt_decrypt(self): + key = os.urandom(32) + iv = os.urandom(16) + + plain_text = os.urandom(random.randint(1024, 4096)) + auth_data = os.urandom(random.randint(128, 512)) + + cipher_text, auth_tag = _a128cbc_hs256_encrypt(key, iv, plain_text, auth_data) + self.assertEqual(plain_text, _a128cbc_hs256_decrypt(key, iv, cipher_text, auth_data, auth_tag)) + + def test_a128cbc_hs256_encrypt_error(self): + key = os.urandom(32) + iv = os.urandom(16) + plain_text = os.urandom(random.randint(1024, 4096)) + auth_data = os.urandom(random.randint(128, 512)) + + with self.assertRaises(ValueError): + # key not specified + _a128cbc_hs256_encrypt(key=None, iv=iv, plaintext=plain_text, authdata=auth_data) + _a128cbc_hs256_encrypt(key=b'', iv=iv, plaintext=plain_text, authdata=auth_data) + # key insufficient len + _a128cbc_hs256_encrypt(key=os.urandom(31), iv=iv, plaintext=plain_text, authdata=auth_data) + + # iv not specified + _a128cbc_hs256_encrypt(key=key, iv=None, plaintext=plain_text, authdata=auth_data) + _a128cbc_hs256_encrypt(key=key, iv=b'', plaintext=plain_text, authdata=auth_data) + # iv incorrect len + _a128cbc_hs256_encrypt(key=key, iv=os.urandom(15), plaintext=plain_text, authdata=auth_data) + _a128cbc_hs256_encrypt(key=key, iv=os.urandom(17), plaintext=plain_text, authdata=auth_data) + + # plaintext not specified + _a128cbc_hs256_encrypt(key=key, iv=iv, plaintext=None, authdata=auth_data) + _a128cbc_hs256_encrypt(key=key, iv=iv, plaintext=b'', authdata=auth_data) + + # authdata not specified + _a128cbc_hs256_encrypt(key=key, iv=iv, plaintext=plain_text, authdata=None) + _a128cbc_hs256_encrypt(key=key, iv=iv, plaintext=plain_text, authdata=b'') + + def test_a128cbc_hs256_decrypt_error(self): + key = os.urandom(32) + iv = os.urandom(16) + cipher_text = os.urandom(random.randint(1024, 4096)) + auth_data = os.urandom(random.randint(128, 512)) + auth_tag = os.urandom(16) + + with self.assertRaises(ValueError): + # key not specified + _a128cbc_hs256_decrypt(key=None, iv=iv, ciphertext=cipher_text, authdata=auth_data, authtag=auth_tag) + _a128cbc_hs256_decrypt(key=b'', iv=iv, ciphertext=cipher_text, authdata=auth_data, authtag=auth_tag) + # key insufficient len + _a128cbc_hs256_decrypt(key=os.urandom(31), iv=iv, ciphertext=cipher_text, authdata=auth_data, authtag=auth_tag) + + # iv not specified + _a128cbc_hs256_decrypt(key=key, iv=None, ciphertext=cipher_text, authdata=auth_data, authtag=auth_tag) + _a128cbc_hs256_decrypt(key=key, iv=b'', ciphertext=cipher_text, authdata=auth_data, authtag=auth_tag) + # iv incorrect len + _a128cbc_hs256_decrypt(key=key, iv=os.urandom(15), ciphertext=cipher_text, authdata=auth_data, authtag=auth_tag) + _a128cbc_hs256_decrypt(key=key, iv=os.urandom(17), ciphertext=cipher_text, authdata=auth_data, authtag=auth_tag) + + # ciphertext not specified + _a128cbc_hs256_decrypt(key=key, iv=iv, ciphertext=None, authdata=auth_data, authtag=auth_tag) + _a128cbc_hs256_decrypt(key=key, iv=iv, ciphertext=b'', authdata=auth_data, authtag=auth_tag) + + # authdata not specified + _a128cbc_hs256_decrypt(key=key, iv=iv, ciphertext=cipher_text, authdata=None, authtag=auth_tag) + _a128cbc_hs256_decrypt(key=key, iv=iv, ciphertext=cipher_text, authdata=b'', authtag=auth_tag) + + # authtag not specified + _a128cbc_hs256_decrypt(key=key, iv=iv, ciphertext=cipher_text, authdata=auth_data, authtag=None) + _a128cbc_hs256_decrypt(key=key, iv=iv, ciphertext=cipher_text, authdata=auth_data, authtag=b'') + # authtag invalid len + _a128cbc_hs256_decrypt(key=key, iv=iv, ciphertext=cipher_text, authdata=auth_data, authtag=os.urandom(17)) + _a128cbc_hs256_decrypt(key=key, iv=iv, ciphertext=cipher_text, authdata=auth_data, authtag=os.urandom(15)) + + def test_private_rsakey_to_from_jwk(self): + # create a key1 export to jwk and import as key2 + key1 = _RsaKey.generate() + jwk = key1.to_jwk(include_private=True) + key2 = _RsaKey.from_jwk(jwk) + + # validate that key2 is a private key + self.assertTrue(key2.is_private_key()) + + # validate that key2 can encrypt and decrypt properly + unwrapped = os.urandom(32) + wrapped = key1.encrypt(unwrapped) + self.assertEqual(unwrapped, key2.decrypt(wrapped)) + wrapped = key2.encrypt(unwrapped) + self.assertEqual(unwrapped, key1.decrypt(wrapped)) + + # validate that key2 can sign and verify properly + data = os.urandom(random.randint(1024, 4096)) + signature = key1.sign(data) + key2.verify(signature, data) + signature = key2.sign(data) + key1.verify(signature, data) + + # validate that all numbers, both public and private are consistent + self.assertEqual(key1.kid, key2.kid) + self.assertEqual(key1.kty, key2.kty) + self.assertEqual(key1.key_ops, key2.key_ops) + self.assertEqual(key1.n, key2.n) + self.assertEqual(key1.e, key2.e) + self.assertEqual(key1.q, key2.q) + self.assertEqual(key1.p, key2.p) + self.assertEqual(key1.d, key2.d) + self.assertEqual(key1.dq, key2.dq) + self.assertEqual(key1.dp, key2.dp) + self.assertEqual(key1.qi, key2.qi) + + # validate that key2 serializes to the same jwk + self.assertEqual(json.dumps(jwk.serialize()), json.dumps(key2.to_jwk(include_private=True).serialize())) + + def test_public_rsakey_to_from_jwk(self): + # create key1 export public components and import as key2 + key1 = _RsaKey.generate() + jwk = key1.to_jwk() + key2 = _RsaKey.from_jwk(jwk) + + # validate that key2 is not a private key + self.assertFalse(key2.is_private_key()) + + # validate that key2 can encrypt properly + unwrapped = os.urandom(32) + wrapped = key2.encrypt(unwrapped) + self.assertEqual(unwrapped, key1.decrypt(wrapped)) + + # validate that key2 can verify properly + data = os.urandom(random.randint(1024, 4096)) + signature = key1.sign(data) + key2.verify(signature, data) + + # validate that all public numbers consistent + self.assertEqual(key1.kid, key2.kid) + self.assertEqual(key1.kty, key2.kty) + self.assertEqual(key1.n, key2.n) + self.assertEqual(key1.e, key2.e) + + # validate that all private numbers are not present + self.assertIsNone(key2.q) + self.assertIsNone(key2.p) + self.assertIsNone(key2.d) + self.assertIsNone(key2.dq) + self.assertIsNone(key2.dp) + self.assertIsNone(key2.qi) + + # validate that key2 serializes to the same public jwk + self.assertEqual(json.dumps(jwk.serialize()), json.dumps(key2.to_jwk().serialize())) + + def test_jws_header_to_from_compact_header(self): + head1 = _JwsHeader() + head1.alg = 'RS256' + head1.kid = str(uuid.uuid4()) + head1.at = self._random_str(random.randint(512, 1024)) + head1.ts = int(time.time()) + head1.typ = 'PoP' + + compact = head1.to_compact_header() + + head2 = _JwsHeader.from_compact_header(compact) + + # assert that all header values match + self.assertEqual(head1.alg, head2.alg) + self.assertEqual(head1.kid, head2.kid) + self.assertEqual(head1.at, head2.at) + self.assertEqual(head1.ts, head2.ts) + self.assertEqual(head1.typ, head2.typ) + + def test_jwe_header_to_from_compact_header(self): + head1 = _JweHeader() + head1.alg = 'RSA-OAEP' + head1.kid = str(uuid.uuid4()) + head1.enc = 'A128CBC-HS256' + + compact = head1.to_compact_header() + head2 = _JweHeader.from_compact_header(compact) + + # assert that all header values match + self.assertEqual(head1.alg, head2.alg) + self.assertEqual(head1.kid, head2.kid) + self.assertEqual(head1.enc, head2.enc) + + def _random_str(self, length): + return ''.join(random.choice(string.printable) for i in range(length)) + + def _assert_bytes_significantly_equal(self, b1, b2): + self.assertEqual(b1.lstrip(b'\x00'), b2.lstrip(b'\x00')) diff --git a/azure-mgmt/tests/keyvault_testcase.py b/azure-mgmt/tests/keyvault_testcase.py index 5da788478902..75398f04fcf1 100644 --- a/azure-mgmt/tests/keyvault_testcase.py +++ b/azure-mgmt/tests/keyvault_testcase.py @@ -13,7 +13,7 @@ from azure.mgmt.keyvault.models import \ (VaultCreateOrUpdateParameters, VaultProperties, Sku, AccessPolicyEntry, Permissions, KeyPermissions, SecretPermissions, SkuName, CertificatePermissions, StoragePermissions) -from azure.keyvault import KeyVaultClient, KeyVaultAuthentication, KeyVaultAuthBase, HttpBearerChallenge +from azure.keyvault import KeyVaultClient, KeyVaultAuthentication, KeyVaultAuthBase, HttpChallenge from azure.common.exceptions import ( CloudError @@ -124,8 +124,8 @@ def setUp(self): super(AzureKeyVaultTestCase, self).setUp() def mock_key_vault_auth_base(self, request): - challenge = HttpBearerChallenge(request.url, 'Bearer authorization=fake-url,resource=https://vault.azure.net') - self.set_authorization_header(request, challenge) + challenge = HttpChallenge(request.url, 'Bearer authorization=fake-url,resource=https://vault.azure.net') + security = self._get_message_security(request, challenge) return request self.fake_settings = fake_settings diff --git a/azure.sln b/azure.sln index 4dcd8618b457..30a724e541c4 100644 --- a/azure.sln +++ b/azure.sln @@ -1,7 +1,7 @@  Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio 15 -VisualStudioVersion = 15.0.26602.0 +VisualStudioVersion = 15.0.27004.2002 MinimumVisualStudioVersion = 10.0.40219.1 Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "azure", "azure.pyproj", "{25B2C65A-0553-4452-8907-8B5B17544E68}" EndProject @@ -33,4 +33,7 @@ Global GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {D163CF85-8B59-44F9-9FE1-2EF9A3EB26F0} + EndGlobalSection EndGlobal