diff --git a/backend/apps/common/model_fields.py b/backend/apps/common/model_fields.py index 67a8b09840..dea5923e48 100644 --- a/backend/apps/common/model_fields.py +++ b/backend/apps/common/model_fields.py @@ -3,7 +3,7 @@ from django.conf import settings from django.db import models -from apps.nest.clients.kms import kms_client +from apps.nest.clients.kms import KmsClient KMS_ERROR_MESSAGE = "AWS KMS is not enabled." STR_ERROR_MESSAGE = "Value must be a string to encrypt." @@ -23,7 +23,7 @@ def get_prep_value(self, value): return None if not isinstance(value, str): raise TypeError(STR_ERROR_MESSAGE) - ciphertext = kms_client.encrypt(value) + ciphertext = KmsClient().encrypt(value) return super().get_prep_value(ciphertext) def _get_from_db_helper(self, value): @@ -36,7 +36,7 @@ def _get_from_db_helper(self, value): return value if not isinstance(value, (bytes, bytearray, memoryview)): raise TypeError(BYTES_ERROR_MESSAGE) - return kms_client.decrypt(bytes(value)) if value else None + return KmsClient().decrypt(bytes(value)) if value else None def from_db_value(self, value, expression, connection): """Decrypt the value from the database using KMS.""" diff --git a/backend/apps/nest/clients/kms.py b/backend/apps/nest/clients/kms.py index 8ad13b683a..2dd1614cb3 100644 --- a/backend/apps/nest/clients/kms.py +++ b/backend/apps/nest/clients/kms.py @@ -1,5 +1,7 @@ """AWS KMS client.""" +from threading import Lock + import boto3 from django.conf import settings @@ -7,8 +9,20 @@ class KmsClient: """AWS KMS Client.""" + _lock = Lock() + + def __new__(cls): + """Create a new instance of KmsClient.""" + if not hasattr(cls, "instance"): + with cls._lock: + if not hasattr(cls, "instance"): + cls.instance = super().__new__(cls) + return cls.instance + def __init__(self): """Initialize the KMS client.""" + if getattr(self, "client", None) is not None: + return self.client = boto3.client( "kms", aws_access_key_id=settings.AWS_ACCESS_KEY_ID, @@ -28,6 +42,3 @@ def encrypt(self, text: str) -> bytes: KeyId=settings.AWS_KMS_KEY_ID, Plaintext=text.encode("utf-8"), )["CiphertextBlob"] - - -kms_client = KmsClient() diff --git a/backend/tests/apps/nest/clients/kms_test.py b/backend/tests/apps/nest/clients/kms_test.py index 838c31d289..07f79fdbcf 100644 --- a/backend/tests/apps/nest/clients/kms_test.py +++ b/backend/tests/apps/nest/clients/kms_test.py @@ -18,6 +18,8 @@ def setup(self): self.mock_boto3_client = mock_boto3_client self.mock_kms_client = Mock() self.mock_boto3_client.return_value = self.mock_kms_client + if hasattr(KmsClient, "instance"): + del KmsClient.instance self.kms_client = KmsClient() @override_settings(AWS_REGION="us-west-2", AWS_KMS_KEY_ID="test_key_id") @@ -38,3 +40,11 @@ def test_decrypt(self): decrypted = self.kms_client.decrypt(b"encrypted") assert decrypted == "test" self.mock_kms_client.decrypt.assert_called_once_with(CiphertextBlob=b"encrypted") + + @override_settings(AWS_REGION="us-west-2", AWS_KMS_KEY_ID="test_key_id") + def test_singleton(self): + """Test singleton behavior.""" + instance1 = KmsClient() + instance2 = KmsClient() + assert instance1 is instance2 + assert instance1.client is instance2.client