Skip to content
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
### Fixed
- Correct typing for paging methods

### Added
- Added method `parse_key_vault_certificate_id` that parses out a full ID returned by key vault, so users
can easily access the certificate's `name`, `vault_url`, and `version`.


## 4.2.0 (2020-08-11)
### Fixed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
LifetimeAction,
KeyVaultCertificate
)
from ._parse_id import parse_key_vault_certificate_id
from ._shared.client_base import ApiVersion
from ._shared import ParsedId

__all__ = [
"ApiVersion",
Expand All @@ -47,7 +49,9 @@
"CertificateContentType",
"WellKnownIssuerNames",
"CertificateIssuer",
"IssuerProperties"
"IssuerProperties",
"parse_key_vault_certificate_id",
"ParsedId",
]

from ._version import VERSION
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# ------------------------------------

# pylint: disable=too-many-lines,too-many-public-methods
from ._shared import parse_vault_id
from ._shared import parse_key_vault_identifier
from ._generated.v7_1 import models
from ._enums import(
CertificatePolicyAction,
Expand Down Expand Up @@ -147,7 +147,7 @@ def __init__(self, **kwargs):
# type: (**Any) -> None
self._attributes = kwargs.pop("attributes", None)
self._id = kwargs.pop("cert_id", None)
self._vault_id = parse_vault_id(self._id)
self._vault_id = parse_key_vault_identifier(self._id)
self._x509_thumbprint = kwargs.pop("x509_thumbprint", None)
self._tags = kwargs.pop("tags", None)

Expand Down Expand Up @@ -430,7 +430,7 @@ def __init__(
):
# type: (...) -> None
self._id = cert_operation_id
self._vault_id = parse_vault_id(cert_operation_id)
self._vault_id = parse_key_vault_identifier(cert_operation_id)
self._issuer_name = issuer_name
self._certificate_type = certificate_type
self._certificate_transparency = certificate_transparency
Expand Down Expand Up @@ -1058,7 +1058,7 @@ class IssuerProperties(object):
def __init__(self, provider=None, **kwargs):
# type: (Optional[str], **Any) -> None
self._id = kwargs.pop("issuer_id", None)
self._vault_id = parse_vault_id(self._id)
self._vault_id = parse_key_vault_identifier(self._id)
self._provider = provider

def __repr__(self):
Expand Down Expand Up @@ -1120,7 +1120,7 @@ def __init__(
self._organization_id = organization_id
self._admin_contacts = admin_contacts
self._id = kwargs.pop("issuer_id", None)
self._vault_id = parse_vault_id(self._id)
self._vault_id = parse_key_vault_identifier(self._id)

def __repr__(self):
# type () -> str
Expand Down Expand Up @@ -1158,7 +1158,8 @@ def id(self):
def name(self):
# type: () -> str
# Issuer name is listed under version under vault_id
# This is because the id we pass to parse_vault_id has an extra segment, so where most cases the version of
# This is because the id we pass to parse_key_vault_identifier has an extra segment,
# so where most cases the version of
# The general pattern is certificates/name/version, but here we have certificates/issuers/name/version
# Issuers are not versioned.
""":rtype: str"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

from ._shared import parse_key_vault_identifier, ParsedId


def parse_key_vault_certificate_id(source_id):
# type: (str) -> ParsedId
"""Parses a full certificate's ID into a class.

:param str source_id: the full original identifier of a certificate
:returns: Returns a parsed certificate id
:rtype: ~azure.keyvault.certificates.ParsedId
:raises: ValueError

Example:
.. literalinclude:: ../tests/test_parse_id.py
:start-after: [START parse_key_vault_certificate_id]
:end-before: [END parse_key_vault_certificate_id]
:language: python
:caption: Parse a certificate's ID
:dedent: 8
"""
parsed_id = parse_key_vault_identifier(source_id)

return ParsedId(
name=parsed_id.name,
source_id=parsed_id.source_id,
vault_url=parsed_id.vault_url,
version=parsed_id.version
)
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from collections import namedtuple

try:
import urllib.parse as parse
except ImportError:
Expand All @@ -24,25 +22,45 @@
"KeyVaultClientBase",
]

_VaultId = namedtuple("VaultId", ["vault_url", "collection", "name", "version"])
class ParsedId():
"""Represents a key vault identifier and its parsed contents.

:param str source_id: The originally received complete identifier
:param str vault_url: The vault URL
:param str name: The name extracted from the id
:param str version: The version extracted from the id
"""

def __init__(
self,
source_id, # type: str
vault_url, # type: str
name, # type: str
version=None # type: Optional[str]
):
self.source_id = source_id
self.vault_url = vault_url
self.name = name
self.version = version


def parse_vault_id(url):
def parse_key_vault_identifier(source_id):
# type: (str) -> ParsedId
try:
parsed_uri = parse.urlparse(url)
parsed_uri = parse.urlparse(source_id)
except Exception: # pylint: disable=broad-except
raise ValueError("'{}' is not not a valid url".format(url))
raise ValueError("'{}' is not not a valid ID".format(source_id))
if not (parsed_uri.scheme and parsed_uri.hostname):
raise ValueError("'{}' is not not a valid url".format(url))
raise ValueError("'{}' is not not a valid ID".format(source_id))

path = list(filter(None, parsed_uri.path.split("/")))

if len(path) < 2 or len(path) > 3:
raise ValueError("'{}' is not not a valid vault url".format(url))
raise ValueError("'{}' is not not a valid vault ID".format(source_id))

return _VaultId(
return ParsedId(
source_id=source_id,
vault_url="{}://{}".format(parsed_uri.scheme, parsed_uri.hostname),
collection=path[0],
name=path[1],
version=path[2] if len(path) == 3 else None,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
CertificateIssuer,
IssuerProperties,
)
from azure.keyvault.certificates._shared import parse_vault_id
from azure.keyvault.certificates._shared import parse_key_vault_identifier
from devtools_testutils import ResourceGroupPreparer, KeyVaultPreparer

from _shared.preparer import KeyVaultClientPreparer as _KeyVaultClientPreparer
Expand Down Expand Up @@ -83,7 +83,7 @@ def _validate_certificate_operation(self, pending_cert_operation, vault, cert_na
self.assertIsNotNone(pending_cert_operation)
self.assertIsNotNone(pending_cert_operation.csr)
self.assertEqual(original_cert_policy.issuer_name, pending_cert_operation.issuer_name)
pending_id = parse_vault_id(pending_cert_operation.id)
pending_id = parse_key_vault_identifier(pending_cert_operation.id)
self.assertEqual(pending_id.vault_url.strip("/"), vault.strip("/"))
self.assertEqual(pending_id.name, cert_name)

Expand Down Expand Up @@ -258,9 +258,10 @@ def test_list(self, client, **kwargs):
error_count = 0
try:
cert_bundle = self._import_common_certificate(client=client, cert_name=cert_name)
parsed_id = parse_vault_id(url=cert_bundle.id)
cid = parsed_id.vault_url + "/" + parsed_id.collection + "/" + parsed_id.name
expected[cid.strip("/")] = cert_bundle
# going to remove the id from the last '/' onwards. This is because list
# properties of certificates doesn't return the version in the id
cid = "/".join(cert_bundle.id.split("/")[:-1])
expected[cid] = cert_bundle
except Exception as ex:
if hasattr(ex, "message") and "Throttled" in ex.message:
error_count += 1
Expand All @@ -287,9 +288,7 @@ def test_list_certificate_versions(self, client, **kwargs):
error_count = 0
try:
cert_bundle = self._import_common_certificate(client=client, cert_name=cert_name)
parsed_id = parse_vault_id(url=cert_bundle.id)
cid = parsed_id.vault_url + "/" + parsed_id.collection + "/" + parsed_id.name + "/" + parsed_id.version
expected[cid.strip("/")] = cert_bundle
expected[cert_bundle.id] = cert_bundle
except Exception as ex:
if hasattr(ex, "message") and "Throttled" in ex.message:
error_count += 1
Expand Down Expand Up @@ -355,7 +354,7 @@ def test_recover_and_purge(self, client, **kwargs):
client.begin_delete_certificate(certificate_name=cert_name).wait()

# validate all our deleted certificates are returned by list_deleted_certificates
deleted = [parse_vault_id(url=c.id).name for c in client.list_deleted_certificates()]
deleted = [parse_key_vault_identifier(source_id=c.id).name for c in client.list_deleted_certificates()]
self.assertTrue(all(c in deleted for c in certs.keys()))

# recover select certificates
Expand All @@ -370,7 +369,7 @@ def test_recover_and_purge(self, client, **kwargs):
time.sleep(50)

# validate none of our deleted certificates are returned by list_deleted_certificates
deleted = [parse_vault_id(url=c.id).name for c in client.list_deleted_certificates()]
deleted = [parse_key_vault_identifier(source_id=c.id).name for c in client.list_deleted_certificates()]
self.assertTrue(not any(c in deleted for c in certs.keys()))

# validate the recovered certificates
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
LifetimeAction,
CertificateIssuer,
IssuerProperties,
parse_key_vault_certificate_id,
)
from azure.keyvault.certificates.aio import CertificateClient
from azure.keyvault.certificates._shared import parse_vault_id
from devtools_testutils import ResourceGroupPreparer, KeyVaultPreparer
import pytest

Expand Down Expand Up @@ -77,7 +77,7 @@ def _validate_certificate_operation(self, pending_cert_operation, vault, cert_na
self.assertIsNotNone(pending_cert_operation)
self.assertIsNotNone(pending_cert_operation.csr)
self.assertEqual(original_cert_policy.issuer_name, pending_cert_operation.issuer_name)
pending_id = parse_vault_id(pending_cert_operation.id)
pending_id = parse_key_vault_certificate_id(pending_cert_operation.id)
self.assertEqual(pending_id.vault_url.strip("/"), vault.strip("/"))
self.assertEqual(pending_id.name, cert_name)

Expand Down Expand Up @@ -253,9 +253,10 @@ async def test_list(self, client, **kwargs):
error_count = 0
try:
cert_bundle = await self._import_common_certificate(client=client, cert_name=cert_name)
parsed_id = parse_vault_id(url=cert_bundle.id)
cid = parsed_id.vault_url + "/" + parsed_id.collection + "/" + parsed_id.name
expected[cid.strip("/")] = cert_bundle
# going to remove the id from the last '/' onwards. This is because list
# properties of certificates doesn't return the version in the id
cid = "/".join(cert_bundle.id.split("/")[:-1])
expected[cid] = cert_bundle
except Exception as ex:
if hasattr(ex, "message") and "Throttled" in ex.message:
error_count += 1
Expand All @@ -282,9 +283,7 @@ async def test_list_certificate_versions(self, client, **kwargs):
error_count = 0
try:
cert_bundle = await self._import_common_certificate(client=client, cert_name=cert_name)
parsed_id = parse_vault_id(url=cert_bundle.id)
cid = parsed_id.vault_url + "/" + parsed_id.collection + "/" + parsed_id.name + "/" + parsed_id.version
expected[cid.strip("/")] = cert_bundle
expected[cert_bundle.id.strip("/")] = cert_bundle
except Exception as ex:
if hasattr(ex, "message") and "Throttled" in ex.message:
error_count += 1
Expand Down Expand Up @@ -355,7 +354,7 @@ async def test_recover_and_purge(self, client, **kwargs):
deleted_certificates = client.list_deleted_certificates()
deleted = []
async for c in deleted_certificates:
deleted.append(parse_vault_id(url=c.id).name)
deleted.append(parse_key_vault_certificate_id(source_id=c.id).name)
self.assertTrue(all(c in deleted for c in certs.keys()))

# recover select certificates
Expand All @@ -373,7 +372,7 @@ async def test_recover_and_purge(self, client, **kwargs):
deleted_certificates = client.list_deleted_certificates()
deleted = []
async for c in deleted_certificates:
deleted.append(parse_vault_id(url=c.id).name)
deleted.append(parse_key_vault_certificate_id(source_id=c.id).name)
self.assertTrue(not any(c in deleted for c in certs.keys()))

# validate the recovered certificates
Expand Down
42 changes: 42 additions & 0 deletions sdk/keyvault/azure-keyvault-certificates/tests/test_parse_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# -------------------------------------
import pytest
from azure.keyvault.certificates import parse_key_vault_certificate_id
from _shared.test_case import KeyVaultTestCase


class TestParseId(KeyVaultTestCase):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests can be module-level methods.

def test_parse_certificate_id_with_version(self):
# [START parse_key_vault_certificate_id]
source_id = "https://keyvault-name.vault.azure.net/certificates/certificate-name/version"
parsed_certificate_id = parse_key_vault_certificate_id(source_id)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Parsing a literal seems odd in the snippet showing people how to use the method. This is what I imagine people would usually want to do:

cert = client.get_certificate("...")
parsed_id = parse_key_vault_certificate_id(cert.id)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like that!


print(parsed_certificate_id.name)
print(parsed_certificate_id.vault_url)
print(parsed_certificate_id.version)
print(parsed_certificate_id.source_id)
# [END parse_key_vault_certificate_id]
assert parsed_certificate_id.name == "certificate-name"
assert parsed_certificate_id.vault_url == "https://keyvault-name.vault.azure.net"
assert parsed_certificate_id.version == "version"
assert parsed_certificate_id.source_id == "https://keyvault-name.vault.azure.net/certificates/certificate-name/version"

def test_parse_certificate_id_with_pending_version(self):
source_id = "https://keyvault-name.vault.azure.net/certificates/certificate-name/pending"
parsed_certificate_id = parse_key_vault_certificate_id(source_id)

assert parsed_certificate_id.name == "certificate-name"
assert parsed_certificate_id.vault_url == "https://keyvault-name.vault.azure.net"
assert parsed_certificate_id.version == "pending"
assert parsed_certificate_id.source_id == "https://keyvault-name.vault.azure.net/certificates/certificate-name/pending"

def test_parse_deleted_certificate_id(self):
source_id = "https://keyvault-name.vault.azure.net/deletedcertificates/deleted-certificate"
parsed_certificate_id = parse_key_vault_certificate_id(source_id)

assert parsed_certificate_id.name == "deleted-certificate"
assert parsed_certificate_id.vault_url == "https://keyvault-name.vault.azure.net"
assert parsed_certificate_id.version == None
assert parsed_certificate_id.source_id == "https://keyvault-name.vault.azure.net/deletedcertificates/deleted-certificate"