Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

66 changes: 31 additions & 35 deletions sdk/keyvault/azure-keyvault-certificates/tests/test_parse_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# -------------------------------------
import pytest
import functools
from azure.keyvault.certificates import CertificateClient, CertificatePolicy, parse_key_vault_certificate_id
from devtools_testutils import ResourceGroupPreparer, KeyVaultPreparer

from _shared.preparer import KeyVaultClientPreparer as _KeyVaultClientPreparer
from _shared.preparer import KeyVaultClientPreparer
from _shared.test_case import KeyVaultTestCase

# pre-apply the client_cls positional argument so it needn't be explicitly passed below
KeyVaultClientPreparer = functools.partial(_KeyVaultClientPreparer, CertificateClient)


class TestParseId(KeyVaultTestCase):
@ResourceGroupPreparer(random_name_enabled=True)
@KeyVaultPreparer()
@KeyVaultClientPreparer()
@KeyVaultClientPreparer(CertificateClient)
def test_parse_certificate_id_with_version(self, client):
cert_name = self.get_resource_name("cert")
# create certificate
Expand All @@ -32,31 +27,32 @@ def test_parse_certificate_id_with_version(self, client):
print(parsed_certificate_id.version)
print(parsed_certificate_id.source_id)
# [END parse_key_vault_certificate_id]
self.assertEqual(parsed_certificate_id.name, cert_name)
self.assertEqual(parsed_certificate_id.vault_url, client.vault_url)
self.assertEqual(parsed_certificate_id.version, cert.properties.version)
self.assertEqual(parsed_certificate_id.source_id, cert.id)

def test_parse_certificate_id_with_pending_version(self):
source_id = "https://keyvault-name.vault.azure.net/certificates/certificate-name/pending"
parsed_certificate_id = parse_key_vault_certificate_id(source_id)

self.assertEqual(parsed_certificate_id.name, "certificate-name")
self.assertEqual(parsed_certificate_id.vault_url, "https://keyvault-name.vault.azure.net")
self.assertEqual(parsed_certificate_id.version, "pending")
self.assertEqual(
parsed_certificate_id.source_id,
"https://keyvault-name.vault.azure.net/certificates/certificate-name/pending",
)

def test_parse_deleted_certificate_id(self):
source_id = "https://keyvault-name.vault.azure.net/deletedcertificates/deleted-certificate"
parsed_certificate_id = parse_key_vault_certificate_id(source_id)

self.assertEqual(parsed_certificate_id.name, "deleted-certificate")
self.assertEqual(parsed_certificate_id.vault_url, "https://keyvault-name.vault.azure.net")
self.assertIsNone(parsed_certificate_id.version)
self.assertEqual(
parsed_certificate_id.source_id,
"https://keyvault-name.vault.azure.net/deletedcertificates/deleted-certificate",
)
assert parsed_certificate_id.name == cert_name
assert parsed_certificate_id.vault_url == client.vault_url
assert parsed_certificate_id.version == cert.properties.version
assert parsed_certificate_id.source_id == cert.id


def test_parse_certificate_id_with_pending_version():
source_id = "https://keyvault-name.vault.azure.net/certificates/certificate-name/pending"
parsed_certificate_id = parse_key_vault_certificate_id(source_id)

assert parsed_certificate_id.name == "certificate-name"
assert parsed_certificate_id.vault_url == "https://keyvault-name.vault.azure.net"
assert parsed_certificate_id.version == "pending"
assert (
parsed_certificate_id.source_id == "https://keyvault-name.vault.azure.net/certificates/certificate-name/pending"
)


def test_parse_deleted_certificate_id():
source_id = "https://keyvault-name.vault.azure.net/deletedcertificates/deleted-certificate"
parsed_certificate_id = parse_key_vault_certificate_id(source_id)

assert parsed_certificate_id.name == "deleted-certificate"
assert parsed_certificate_id.vault_url == "https://keyvault-name.vault.azure.net"
assert parsed_certificate_id.version is None
assert (
parsed_certificate_id.source_id
== "https://keyvault-name.vault.azure.net/deletedcertificates/deleted-certificate"
)
3 changes: 3 additions & 0 deletions sdk/keyvault/azure-keyvault-keys/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Release History

## 4.3.1 (Unreleased)
### Added
- Added method `parse_key_vault_key_id` that parses out a full ID returned by Key Vault, so users can easily
access the key's `name`, `vault_url`, and `version`.


## 4.3.0 (2020-10-06)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# Licensed under the MIT License.
# -------------------------------------
from ._enums import KeyCurveName, KeyOperation, KeyType
from ._parse_id import parse_key_vault_key_id
from ._shared.client_base import ApiVersion
from ._shared import KeyVaultResourceId
from ._models import DeletedKey, JsonWebKey, KeyProperties, KeyVaultKey
from ._client import KeyClient

Expand All @@ -17,6 +19,8 @@
"KeyType",
"DeletedKey",
"KeyProperties",
"parse_key_vault_key_id",
"KeyVaultResourceId"
]

from ._version import VERSION
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License.
# -------------------------------------
from collections import namedtuple
from ._shared import parse_vault_id
from ._shared import parse_key_vault_id
from ._generated.v7_1.models import JsonWebKey as _JsonWebKey

try:
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(self, key_id, attributes=None, **kwargs):
# type: (str, Optional[_models.KeyAttributes], **Any) -> None
self._attributes = attributes
self._id = key_id
self._vault_id = parse_vault_id(key_id)
self._vault_id = parse_key_vault_id(key_id)
self._managed = kwargs.get("managed", None)
self._tags = kwargs.get("tags", None)

Expand Down
29 changes: 29 additions & 0 deletions sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_parse_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

from ._shared import parse_key_vault_id, KeyVaultResourceId


def parse_key_vault_key_id(source_id):
# type: (str) -> KeyVaultResourceId
"""Parses a key's full ID into a class with parsed contents as attributes.

:param str source_id: the full original identifier of a key
:returns: Returns a parsed key ID as a :class:`KeyVaultResourceId`
:rtype: ~azure.keyvault.keys.KeyVaultResourceId
:raises: ValueError
Example:
.. literalinclude:: ../tests/test_parse_id.py
:start-after: [START parse_key_vault_key_id]
:end-before: [END parse_key_vault_key_id]
:language: python
:caption: Parse a key's ID
:dedent: 8
"""
parsed_id = parse_key_vault_id(source_id)

return KeyVaultResourceId(
name=parsed_id.name, source_id=parsed_id.source_id, vault_url=parsed_id.vault_url, version=parsed_id.version
)
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,22 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from collections import namedtuple

try:
import urllib.parse as parse
except ImportError:
# pylint:disable=import-error
import urlparse as parse # type: ignore

from typing import TYPE_CHECKING
from .challenge_auth_policy import ChallengeAuthPolicy, ChallengeAuthPolicyBase
from .client_base import KeyVaultClientBase
from .http_challenge import HttpChallenge
from . import http_challenge_cache as HttpChallengeCache

if TYPE_CHECKING:
# pylint: disable=unused-import
from typing import Optional


__all__ = [
"ChallengeAuthPolicy",
Expand All @@ -24,25 +27,45 @@
"KeyVaultClientBase",
]

_VaultId = namedtuple("VaultId", ["vault_url", "collection", "name", "version"])
class KeyVaultResourceId():
"""Represents a Key Vault identifier and its parsed contents.

:param str source_id: The complete identifier received from Key Vault
:param str vault_url: The vault URL
:param str name: The name extracted from the ID
:param str version: The version extracted from the ID
"""

def __init__(
self,
source_id, # type: str
vault_url, # type: str
name, # type: str
version=None # type: Optional[str]
):
self.source_id = source_id
self.vault_url = vault_url
self.name = name
self.version = version


def parse_vault_id(url):
def parse_key_vault_id(source_id):
# type: (str) -> KeyVaultResourceId
try:
parsed_uri = parse.urlparse(url)
parsed_uri = parse.urlparse(source_id)
except Exception: # pylint: disable=broad-except
raise ValueError("'{}' is not not a valid url".format(url))
raise ValueError("'{}' is not not a valid url".format(source_id))
if not (parsed_uri.scheme and parsed_uri.hostname):
raise ValueError("'{}' is not not a valid url".format(url))
raise ValueError("'{}' is not not a valid url".format(source_id))

path = list(filter(None, parsed_uri.path.split("/")))

if len(path) < 2 or len(path) > 3:
raise ValueError("'{}' is not not a valid vault url".format(url))
raise ValueError("'{}' is not not a valid vault url".format(source_id))

return _VaultId(
return KeyVaultResourceId(
source_id=source_id,
vault_url="{}://{}".format(parsed_uri.scheme, parsed_uri.hostname),
collection=path[0],
name=path[1],
version=path[2] if len(path) == 3 else None,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ._providers import get_local_cryptography_provider, NoLocalCryptography
from .. import KeyOperation
from .._models import KeyVaultKey
from .._shared import KeyVaultClientBase, parse_vault_id
from .._shared import KeyVaultClientBase, parse_key_vault_id

if TYPE_CHECKING:
# pylint:disable=unused-import
Expand Down Expand Up @@ -53,10 +53,10 @@ def __init__(self, key, credential, **kwargs):

if isinstance(key, KeyVaultKey):
self._key = key
self._key_id = parse_vault_id(key.id)
self._key_id = parse_key_vault_id(key.id)
elif isinstance(key, six.string_types):
self._key = None
self._key_id = parse_vault_id(key)
self._key_id = parse_key_vault_id(key)
self._keys_get_forbidden = None # type: Optional[bool]
else:
raise ValueError("'key' must be a KeyVaultKey instance or a key ID string including a version")
Expand All @@ -76,7 +76,7 @@ def key_id(self):

:rtype: str
"""
return "/".join(self._key_id)
return self._key_id.source_id

@distributed_trace
def _initialize(self, **kwargs):
Expand Down Expand Up @@ -242,7 +242,7 @@ def unwrap_key(self, algorithm, encrypted_key, **kwargs):
parameters=self._models.KeyOperationsParameters(algorithm=algorithm, value=encrypted_key),
**kwargs
)
return UnwrapResult(key_id=self._key_id, algorithm=algorithm, key=operation_result.result)
return UnwrapResult(key_id=self.key_id, algorithm=algorithm, key=operation_result.result)

@distributed_trace
def sign(self, algorithm, digest, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .._providers import get_local_cryptography_provider, NoLocalCryptography
from ... import KeyOperation
from ..._models import KeyVaultKey
from ..._shared import AsyncKeyVaultClientBase, parse_vault_id
from ..._shared import AsyncKeyVaultClientBase, parse_key_vault_id

if TYPE_CHECKING:
# pylint:disable=unused-import
Expand Down Expand Up @@ -50,10 +50,10 @@ class CryptographyClient(AsyncKeyVaultClientBase):
def __init__(self, key: "Union[KeyVaultKey, str]", credential: "AsyncTokenCredential", **kwargs: "Any") -> None:
if isinstance(key, KeyVaultKey):
self._key = key
self._key_id = parse_vault_id(key.id)
self._key_id = parse_key_vault_id(key.id)
elif isinstance(key, str):
self._key = None
self._key_id = parse_vault_id(key)
self._key_id = parse_key_vault_id(key)
self._keys_get_forbidden = None # type: Optional[bool]
else:
raise ValueError("'key' must be a KeyVaultKey instance or a key ID string including a version")
Expand All @@ -72,7 +72,7 @@ def key_id(self) -> str:

:rtype: str
"""
return "/".join(self._key_id)
return self._key_id.source_id

@distributed_trace_async
async def _initialize(self, **kwargs):
Expand Down
Loading