Skip to content

Commit 210e207

Browse files
authored
Use singleton pattern in kms client (#2158)
* Use singleton * Add thread locking
1 parent 06b0527 commit 210e207

File tree

3 files changed

+27
-6
lines changed

3 files changed

+27
-6
lines changed

backend/apps/common/model_fields.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from django.conf import settings
44
from django.db import models
55

6-
from apps.nest.clients.kms import kms_client
6+
from apps.nest.clients.kms import KmsClient
77

88
KMS_ERROR_MESSAGE = "AWS KMS is not enabled."
99
STR_ERROR_MESSAGE = "Value must be a string to encrypt."
@@ -23,7 +23,7 @@ def get_prep_value(self, value):
2323
return None
2424
if not isinstance(value, str):
2525
raise TypeError(STR_ERROR_MESSAGE)
26-
ciphertext = kms_client.encrypt(value)
26+
ciphertext = KmsClient().encrypt(value)
2727
return super().get_prep_value(ciphertext)
2828

2929
def _get_from_db_helper(self, value):
@@ -36,7 +36,7 @@ def _get_from_db_helper(self, value):
3636
return value
3737
if not isinstance(value, (bytes, bytearray, memoryview)):
3838
raise TypeError(BYTES_ERROR_MESSAGE)
39-
return kms_client.decrypt(bytes(value)) if value else None
39+
return KmsClient().decrypt(bytes(value)) if value else None
4040

4141
def from_db_value(self, value, expression, connection):
4242
"""Decrypt the value from the database using KMS."""

backend/apps/nest/clients/kms.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,28 @@
11
"""AWS KMS client."""
22

3+
from threading import Lock
4+
35
import boto3
46
from django.conf import settings
57

68

79
class KmsClient:
810
"""AWS KMS Client."""
911

12+
_lock = Lock()
13+
14+
def __new__(cls):
15+
"""Create a new instance of KmsClient."""
16+
if not hasattr(cls, "instance"):
17+
with cls._lock:
18+
if not hasattr(cls, "instance"):
19+
cls.instance = super().__new__(cls)
20+
return cls.instance
21+
1022
def __init__(self):
1123
"""Initialize the KMS client."""
24+
if getattr(self, "client", None) is not None:
25+
return
1226
self.client = boto3.client(
1327
"kms",
1428
aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
@@ -28,6 +42,3 @@ def encrypt(self, text: str) -> bytes:
2842
KeyId=settings.AWS_KMS_KEY_ID,
2943
Plaintext=text.encode("utf-8"),
3044
)["CiphertextBlob"]
31-
32-
33-
kms_client = KmsClient()

backend/tests/apps/nest/clients/kms_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ def setup(self):
1818
self.mock_boto3_client = mock_boto3_client
1919
self.mock_kms_client = Mock()
2020
self.mock_boto3_client.return_value = self.mock_kms_client
21+
if hasattr(KmsClient, "instance"):
22+
del KmsClient.instance
2123
self.kms_client = KmsClient()
2224

2325
@override_settings(AWS_REGION="us-west-2", AWS_KMS_KEY_ID="test_key_id")
@@ -38,3 +40,11 @@ def test_decrypt(self):
3840
decrypted = self.kms_client.decrypt(b"encrypted")
3941
assert decrypted == "test"
4042
self.mock_kms_client.decrypt.assert_called_once_with(CiphertextBlob=b"encrypted")
43+
44+
@override_settings(AWS_REGION="us-west-2", AWS_KMS_KEY_ID="test_key_id")
45+
def test_singleton(self):
46+
"""Test singleton behavior."""
47+
instance1 = KmsClient()
48+
instance2 = KmsClient()
49+
assert instance1 is instance2
50+
assert instance1.client is instance2.client

0 commit comments

Comments
 (0)