-
Notifications
You must be signed in to change notification settings - Fork 7
Feature RT-66: SDS Client #22
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -108,3 +108,8 @@ venv.bak/ | |
|
|
||
| # PyCharm settings | ||
| .idea | ||
|
|
||
| # Certs | ||
| **/*.cert | ||
| **/*.key | ||
| **/*.pem | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| import asyncio | ||
|
|
||
|
|
||
| def async_test(f): | ||
This conversation was marked as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| A wrapper for asynchronous tests. | ||
| By default unittest will not wait for asynchronous tests to complete even if the async functions are awaited. | ||
| By annotating a test method with `@async_test` it will cause the test to wait for asynchronous activities | ||
| to complete | ||
| :param f: | ||
| :return: | ||
| """ | ||
| def wrapper(*args, **kwargs): | ||
| coro = asyncio.coroutine(f) | ||
| future = coro(*args, **kwargs) | ||
| asyncio.run(future) | ||
|
|
||
| return wrapper | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ verify_ssl = true | |
| integration-adaptors-common = {editable = true,path = "./../common"} | ||
| requests = "*" | ||
| tornado = "*" | ||
| ldap3 = "*" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. General question not specific to this PR, but should we be specifying versions (ie the major version)? |
||
|
|
||
| [requires] | ||
| python_version = "3.7" | ||
|
|
||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
|
|
||
|
|
||
| class RoutingException(Exception): | ||
| """ | ||
| A RoutingException is thrown when an issue arises with the LDAP response, | ||
| particularly when no response is returned from the LDAP query. | ||
| """ | ||
|
|
||
| pass |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,117 @@ | ||
| import asyncio | ||
| import logging | ||
| import ldap3 | ||
| from mhs.routing import routing_exception | ||
| import ldap3.core.exceptions as ldap_exceptions | ||
| from typing import Dict, List | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| NHS_SERVICES_BASE = "ou=services, o=nhs" | ||
|
|
||
| MHS_OBJECT_CLASS = "nhsMhs" | ||
| AS_OBJECT_CLASS = "nhsAs" | ||
| MHS_PARTY_KEY = 'nhsMHSPartyKey' | ||
|
|
||
| mhs_attributes = [ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How did you find out what ldap queries to do? Is there some documentation, or did you just make some ldap queries?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nothing particularly special in terms of ldap queries, just filtering on the interaction id, object code and ods code/org code as suggested here: |
||
| 'nhsEPInteractionType', 'nhsIDCode', 'nhsMhsCPAId', 'nhsMHSEndPoint', 'nhsMhsFQDN', | ||
| 'nhsMHsIN', 'nhsMHSIsAuthenticated', 'nhsMHSPartyKey', 'nhsMHsSN', 'nhsMhsSvcIA', 'nhsProductKey', | ||
| 'uniqueIdentifier', 'nhsMHSAckRequested', 'nhsMHSActor', 'nhsMHSDuplicateElimination', | ||
| 'nhsMHSPersistDuration', 'nhsMHSRetries', 'nhsMHSRetryInterval', 'nhsMHSSyncReplyMode' | ||
| ] | ||
|
|
||
|
|
||
| class SDSClient: | ||
|
|
||
| def __init__(self, sds_connection: ldap3.Connection, timeout: int = 3): | ||
| """ | ||
| :param sds_connection: takes an ldap connection to the sds server | ||
| """ | ||
| if not sds_connection: | ||
| raise ValueError('sds_connection must not be null') | ||
|
|
||
| self.connection = sds_connection | ||
| self.timeout = timeout | ||
|
|
||
| async def get_mhs_details(self, ods_code: str, interaction_id: str) -> Dict: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another point:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, I'd noticed this when looking at type hints, I refrained from using it here since the value type can be more than 1 type but I think in the future we should probably include the key/value types
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In that case, you could use |
||
| """ | ||
| Returns the mhs details for the given org code and interaction id | ||
| :return: Dictionary of the attributes of the mhs associated with the given parameters | ||
| """ | ||
|
|
||
| accredited_system_lookup = await self._accredited_system_lookup(ods_code, interaction_id) | ||
|
|
||
| if not accredited_system_lookup: | ||
| logger.error(f"Failed to find accredited system details for ods code : {ods_code} and interaction id: " | ||
| f"{interaction_id}") | ||
| raise routing_exception.RoutingException('No response from accredited system lookup') | ||
This conversation was marked as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| if len(accredited_system_lookup) > 1: | ||
| logger.warning(f"More than one accredited system details returned on inputs: " | ||
| f"ods: {ods_code} - interaction: {interaction_id}") | ||
|
|
||
| # As per the spec exactly one result should be returned | ||
| response = accredited_system_lookup[0] | ||
| party_key = response['attributes'][MHS_PARTY_KEY] | ||
|
|
||
| details = await self._mhs_details_lookup(party_key, interaction_id) | ||
|
|
||
| if not details: | ||
| logger.error(f'No mhs details returned for party key: {party_key} and interaction id : {interaction_id}') | ||
| raise routing_exception.RoutingException(f'No mhs details returned for party key: ' | ||
| f'{party_key} and interaction id : {interaction_id}') | ||
| if len(details) > 1: | ||
| logger.warning(f"More than one mhs details returned on inputs: " | ||
| f"ods: {ods_code} - interaction: {interaction_id}") | ||
| return details[0]['attributes'] | ||
|
|
||
| async def _accredited_system_lookup(self, ods_code: str, interaction_id: str) -> List: | ||
| """ | ||
| Used to find an accredited system, the result contains the nhsMhsPartyKey. | ||
| This can then be used to find an MHS endpoint | ||
| :return: endpoint details - filtered to only contain nhsMHSPartyKey | ||
| """ | ||
|
|
||
| search_filter = f"(&(nhsIDCode={ods_code}) (objectClass={AS_OBJECT_CLASS}) (nhsAsSvcIA={interaction_id}))" | ||
|
|
||
| message_id = self.connection.search(search_base=NHS_SERVICES_BASE, | ||
| search_filter=search_filter, | ||
| attributes=MHS_PARTY_KEY) | ||
| logger.info(f'Message id - {message_id} - for query: ods code - {ods_code} ' | ||
| f': interaction id - {interaction_id}') | ||
|
|
||
| response = await self._get_query_result(message_id) | ||
| logger.info(f'Found accredited supplier details for message_id: {message_id}') | ||
|
|
||
| return response | ||
|
|
||
| async def _mhs_details_lookup(self, party_key: str, interaction_id: str) -> List: | ||
| """ | ||
| Given a party key and an interaction id, this will return an object containing the attributes of that party key, | ||
| including the endpoint address | ||
| :return: all the endpoint details | ||
| """ | ||
| search_filter = f"(&(objectClass={MHS_OBJECT_CLASS})" \ | ||
| f" ({MHS_PARTY_KEY}={party_key})" \ | ||
| f" (nhsMhsSvcIA={interaction_id}))" | ||
| message_id = self.connection.search(search_base=NHS_SERVICES_BASE, | ||
| search_filter=search_filter, | ||
| attributes=mhs_attributes) | ||
|
|
||
| logger.info(f'Message id - {message_id} - for query: party key - {party_key} ' | ||
| f': interaction id - {interaction_id}') | ||
|
|
||
| response = await self._get_query_result(message_id) | ||
| logger.info(f'Found mhs details for message_id: {message_id}') | ||
|
|
||
| return response | ||
|
|
||
| async def _get_query_result(self, message_id: int) -> List: | ||
| loop = asyncio.get_event_loop() | ||
| response = [] | ||
| try: | ||
| response, result = await loop.run_in_executor(None, self.connection.get_response, message_id, self.timeout) | ||
| except ldap_exceptions.LDAPResponseTimeoutError: | ||
| logger.error(f'LDAP query timed out for message id: {message_id}') | ||
|
|
||
| return response | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| import ssl | ||
| import ldap3 | ||
|
|
||
|
|
||
| def build_sds_connection(ldap_address: str) -> ldap3.Connection: | ||
| """ | ||
| Given an ldap service address this will return a ldap3 connection object | ||
| """ | ||
| server = ldap3.Server(ldap_address) | ||
| connection = ldap3.Connection(server, auto_bind=True, client_strategy=ldap3.REUSABLE) | ||
| return connection | ||
|
|
||
|
|
||
| def build_sds_connection_tls(ldap_address: str, | ||
| private_key_path: str, | ||
| local_cert_path: str, | ||
| ca_certs_file: str | ||
| ) -> ldap3.Connection: | ||
| """ | ||
| This will return a connection object for the given ip along with loading the given certification files | ||
| :return: Connection object using the given cert files | ||
| """ | ||
|
|
||
| load_tls = ldap3.Tls(local_private_key_file=private_key_path, | ||
| local_certificate_file=local_cert_path, | ||
| validate=ssl.CERT_REQUIRED, version=ssl.PROTOCOL_TLSv1, | ||
| ca_certs_file=ca_certs_file) | ||
|
|
||
| server = ldap3.Server(ldap_address, use_ssl=True, tls=load_tls) | ||
| connection = ldap3.Connection(server, auto_bind=True, client_strategy=ldap3.REUSABLE) | ||
| return connection |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| import mhs.routing.sds as sds | ||
|
|
||
|
|
||
| class MHSAttributeLookupHandler: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the purpose of this class? It looks like it's an incomplete caching layer to me.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is exactly what it is, in the next PR this handler manages the cache checking and calls to the sds client depending on what the cache returns |
||
|
|
||
| def __init__(self, client: sds.SDSClient): | ||
| if not client: | ||
| raise ValueError('sds client required') | ||
| self.sds_client = client | ||
|
|
||
| async def retrieve_mhs_attributes(self, org_code, interaction_id): | ||
| endpoint_details = await self.sds_client.get_mhs_details(org_code, interaction_id) | ||
| return endpoint_details | ||
Uh oh!
There was an error while loading. Please reload this page.