Skip to content
This repository was archived by the owner on Aug 5, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,8 @@ venv.bak/

# PyCharm settings
.idea

# Certs
**/*.cert
**/*.key
**/*.pem
18 changes: 18 additions & 0 deletions common/utilities/test_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import asyncio


def async_test(f):
"""
A wrapper for asynchronous tests.
By default unittest will not wait for asynchronous tests to complete even if the async functions are awaited.
By annotating a test method with `@async_test` it will cause the test to wait for asynchronous activities
to complete
:param f:
:return:
"""
def wrapper(*args, **kwargs):
coro = asyncio.coroutine(f)
future = coro(*args, **kwargs)
asyncio.run(future)

return wrapper
1 change: 1 addition & 0 deletions mhs-reference-implementation/Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ verify_ssl = true
integration-adaptors-common = {editable = true,path = "./../common"}
requests = "*"
tornado = "*"
ldap3 = "*"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General question not specific to this PR, but should we be specifying versions (ie the major version)?


[requires]
python_version = "3.7"
Expand Down
103 changes: 58 additions & 45 deletions mhs-reference-implementation/Pipfile.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Empty file.
9 changes: 9 additions & 0 deletions mhs-reference-implementation/mhs/routing/routing_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@


class RoutingException(Exception):
"""
A RoutingException is thrown when an issue arises with the LDAP response,
particularly when no response is returned from the LDAP query.
"""

pass
117 changes: 117 additions & 0 deletions mhs-reference-implementation/mhs/routing/sds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import asyncio
import logging
import ldap3
from mhs.routing import routing_exception
import ldap3.core.exceptions as ldap_exceptions
from typing import Dict, List

logger = logging.getLogger(__name__)

NHS_SERVICES_BASE = "ou=services, o=nhs"

MHS_OBJECT_CLASS = "nhsMhs"
AS_OBJECT_CLASS = "nhsAs"
MHS_PARTY_KEY = 'nhsMHSPartyKey'

mhs_attributes = [
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How did you find out what ldap queries to do? Is there some documentation, or did you just make some ldap queries?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing particularly special in terms of ldap queries, just filtering on the interaction id, object code and ods code/org code as suggested here:
https://nhsconnect.github.io/gpconnect/integration_spine_directory_service.html

'nhsEPInteractionType', 'nhsIDCode', 'nhsMhsCPAId', 'nhsMHSEndPoint', 'nhsMhsFQDN',
'nhsMHsIN', 'nhsMHSIsAuthenticated', 'nhsMHSPartyKey', 'nhsMHsSN', 'nhsMhsSvcIA', 'nhsProductKey',
'uniqueIdentifier', 'nhsMHSAckRequested', 'nhsMHSActor', 'nhsMHSDuplicateElimination',
'nhsMHSPersistDuration', 'nhsMHSRetries', 'nhsMHSRetryInterval', 'nhsMHSSyncReplyMode'
]


class SDSClient:

def __init__(self, sds_connection: ldap3.Connection, timeout: int = 3):
"""
:param sds_connection: takes an ldap connection to the sds server
"""
if not sds_connection:
raise ValueError('sds_connection must not be null')

self.connection = sds_connection
self.timeout = timeout

async def get_mhs_details(self, ods_code: str, interaction_id: str) -> Dict:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another point: Dict can be made more specific like eg Dict[keyType, valueType]. As can other container types eg List[str]

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I'd noticed this when looking at type hints, I refrained from using it here since the value type can be more than 1 type but I think in the future we should probably include the key/value types

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, you could use Any, or Union[someType, someOtherType], though that could get long-winded.

"""
Returns the mhs details for the given org code and interaction id
:return: Dictionary of the attributes of the mhs associated with the given parameters
"""

accredited_system_lookup = await self._accredited_system_lookup(ods_code, interaction_id)

if not accredited_system_lookup:
logger.error(f"Failed to find accredited system details for ods code : {ods_code} and interaction id: "
f"{interaction_id}")
raise routing_exception.RoutingException('No response from accredited system lookup')

if len(accredited_system_lookup) > 1:
logger.warning(f"More than one accredited system details returned on inputs: "
f"ods: {ods_code} - interaction: {interaction_id}")

# As per the spec exactly one result should be returned
response = accredited_system_lookup[0]
party_key = response['attributes'][MHS_PARTY_KEY]

details = await self._mhs_details_lookup(party_key, interaction_id)

if not details:
logger.error(f'No mhs details returned for party key: {party_key} and interaction id : {interaction_id}')
raise routing_exception.RoutingException(f'No mhs details returned for party key: '
f'{party_key} and interaction id : {interaction_id}')
if len(details) > 1:
logger.warning(f"More than one mhs details returned on inputs: "
f"ods: {ods_code} - interaction: {interaction_id}")
return details[0]['attributes']

async def _accredited_system_lookup(self, ods_code: str, interaction_id: str) -> List:
"""
Used to find an accredited system, the result contains the nhsMhsPartyKey.
This can then be used to find an MHS endpoint
:return: endpoint details - filtered to only contain nhsMHSPartyKey
"""

search_filter = f"(&(nhsIDCode={ods_code}) (objectClass={AS_OBJECT_CLASS}) (nhsAsSvcIA={interaction_id}))"

message_id = self.connection.search(search_base=NHS_SERVICES_BASE,
search_filter=search_filter,
attributes=MHS_PARTY_KEY)
logger.info(f'Message id - {message_id} - for query: ods code - {ods_code} '
f': interaction id - {interaction_id}')

response = await self._get_query_result(message_id)
logger.info(f'Found accredited supplier details for message_id: {message_id}')

return response

async def _mhs_details_lookup(self, party_key: str, interaction_id: str) -> List:
"""
Given a party key and an interaction id, this will return an object containing the attributes of that party key,
including the endpoint address
:return: all the endpoint details
"""
search_filter = f"(&(objectClass={MHS_OBJECT_CLASS})" \
f" ({MHS_PARTY_KEY}={party_key})" \
f" (nhsMhsSvcIA={interaction_id}))"
message_id = self.connection.search(search_base=NHS_SERVICES_BASE,
search_filter=search_filter,
attributes=mhs_attributes)

logger.info(f'Message id - {message_id} - for query: party key - {party_key} '
f': interaction id - {interaction_id}')

response = await self._get_query_result(message_id)
logger.info(f'Found mhs details for message_id: {message_id}')

return response

async def _get_query_result(self, message_id: int) -> List:
loop = asyncio.get_event_loop()
response = []
try:
response, result = await loop.run_in_executor(None, self.connection.get_response, message_id, self.timeout)
except ldap_exceptions.LDAPResponseTimeoutError:
logger.error(f'LDAP query timed out for message id: {message_id}')

return response
31 changes: 31 additions & 0 deletions mhs-reference-implementation/mhs/routing/sds_connection_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import ssl
import ldap3


def build_sds_connection(ldap_address: str) -> ldap3.Connection:
"""
Given an ldap service address this will return a ldap3 connection object
"""
server = ldap3.Server(ldap_address)
connection = ldap3.Connection(server, auto_bind=True, client_strategy=ldap3.REUSABLE)
return connection


def build_sds_connection_tls(ldap_address: str,
private_key_path: str,
local_cert_path: str,
ca_certs_file: str
) -> ldap3.Connection:
"""
This will return a connection object for the given ip along with loading the given certification files
:return: Connection object using the given cert files
"""

load_tls = ldap3.Tls(local_private_key_file=private_key_path,
local_certificate_file=local_cert_path,
validate=ssl.CERT_REQUIRED, version=ssl.PROTOCOL_TLSv1,
ca_certs_file=ca_certs_file)

server = ldap3.Server(ldap_address, use_ssl=True, tls=load_tls)
connection = ldap3.Connection(server, auto_bind=True, client_strategy=ldap3.REUSABLE)
return connection
13 changes: 13 additions & 0 deletions mhs-reference-implementation/mhs/routing/sds_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import mhs.routing.sds as sds


class MHSAttributeLookupHandler:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of this class? It looks like it's an incomplete caching layer to me.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is exactly what it is, in the next PR this handler manages the cache checking and calls to the sds client depending on what the cache returns


def __init__(self, client: sds.SDSClient):
if not client:
raise ValueError('sds client required')
self.sds_client = client

async def retrieve_mhs_attributes(self, org_code, interaction_id):
endpoint_details = await self.sds_client.get_mhs_details(org_code, interaction_id)
return endpoint_details
Loading