Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions backend/apps/common/model_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from django.conf import settings
from django.db import models

from apps.nest.clients.kms import kms_client
from apps.nest.clients.kms import KmsClient

KMS_ERROR_MESSAGE = "AWS KMS is not enabled."
STR_ERROR_MESSAGE = "Value must be a string to encrypt."
Expand All @@ -23,7 +23,7 @@ def get_prep_value(self, value):
return None
if not isinstance(value, str):
raise TypeError(STR_ERROR_MESSAGE)
ciphertext = kms_client.encrypt(value)
ciphertext = KmsClient().encrypt(value)
return super().get_prep_value(ciphertext)

def _get_from_db_helper(self, value):
Expand All @@ -36,7 +36,7 @@ def _get_from_db_helper(self, value):
return value
if not isinstance(value, (bytes, bytearray, memoryview)):
raise TypeError(BYTES_ERROR_MESSAGE)
return kms_client.decrypt(bytes(value)) if value else None
return KmsClient().decrypt(bytes(value)) if value else None

def from_db_value(self, value, expression, connection):
"""Decrypt the value from the database using KMS."""
Expand Down
17 changes: 14 additions & 3 deletions backend/apps/nest/clients/kms.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
"""AWS KMS client."""

from threading import Lock

import boto3
from django.conf import settings


class KmsClient:
"""AWS KMS Client."""

_lock = Lock()

def __new__(cls):
"""Create a new instance of KmsClient."""
if not hasattr(cls, "instance"):
with cls._lock:
if not hasattr(cls, "instance"):
cls.instance = super().__new__(cls)
return cls.instance

def __init__(self):
"""Initialize the KMS client."""
if getattr(self, "client", None) is not None:
return
self.client = boto3.client(
"kms",
aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
Expand All @@ -28,6 +42,3 @@ def encrypt(self, text: str) -> bytes:
KeyId=settings.AWS_KMS_KEY_ID,
Plaintext=text.encode("utf-8"),
)["CiphertextBlob"]


kms_client = KmsClient()
10 changes: 10 additions & 0 deletions backend/tests/apps/nest/clients/kms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def setup(self):
self.mock_boto3_client = mock_boto3_client
self.mock_kms_client = Mock()
self.mock_boto3_client.return_value = self.mock_kms_client
if hasattr(KmsClient, "instance"):
del KmsClient.instance
self.kms_client = KmsClient()

@override_settings(AWS_REGION="us-west-2", AWS_KMS_KEY_ID="test_key_id")
Expand All @@ -38,3 +40,11 @@ def test_decrypt(self):
decrypted = self.kms_client.decrypt(b"encrypted")
assert decrypted == "test"
self.mock_kms_client.decrypt.assert_called_once_with(CiphertextBlob=b"encrypted")

@override_settings(AWS_REGION="us-west-2", AWS_KMS_KEY_ID="test_key_id")
def test_singleton(self):
"""Test singleton behavior."""
instance1 = KmsClient()
instance2 = KmsClient()
assert instance1 is instance2
assert instance1.client is instance2.client