diff --git a/msal_extensions/persistence.py b/msal_extensions/persistence.py index e4b6c66..dd8b890 100644 --- a/msal_extensions/persistence.py +++ b/msal_extensions/persistence.py @@ -9,6 +9,7 @@ import abc import os import errno +import hashlib import logging import sys try: @@ -50,6 +51,9 @@ def _mkdir_p(path): else: raise +def _auto_hash(input_string): + return hashlib.sha256(input_string.encode('utf-8')).hexdigest() + # We do not aim to wrap every os-specific exception. # Here we define only the most common one, @@ -197,19 +201,18 @@ class KeychainPersistence(BasePersistence): and protected by native Keychain libraries on OSX""" is_encrypted = True - def __init__(self, signal_location, service_name, account_name): + def __init__(self, signal_location, service_name=None, account_name=None): """Initialization could fail due to unsatisfied dependency. :param signal_location: See :func:`persistence.LibsecretPersistence.__init__` """ - if not (service_name and account_name): # It would hang on OSX - raise ValueError("service_name and account_name are required") from .osx import Keychain, KeychainError # pylint: disable=import-outside-toplevel self._file_persistence = FilePersistence(signal_location) # Favor composition self._Keychain = Keychain # pylint: disable=invalid-name self._KeychainError = KeychainError # pylint: disable=invalid-name - self._service_name = service_name - self._account_name = account_name + default_service_name = "msal-extensions" # This is also our package name + self._service_name = service_name or default_service_name + self._account_name = account_name or _auto_hash(signal_location) def save(self, content): with self._Keychain() as locker: @@ -247,7 +250,7 @@ class LibsecretPersistence(BasePersistence): and protected by native libsecret libraries on Linux""" is_encrypted = True - def __init__(self, signal_location, schema_name, attributes, **kwargs): + def __init__(self, signal_location, schema_name=None, attributes=None, **kwargs): """Initialization could fail due to unsatisfied dependency. :param string signal_location: @@ -262,7 +265,8 @@ def __init__(self, signal_location, schema_name, attributes, **kwargs): from .libsecret import ( # This uncertain import is deferred till runtime LibSecretAgent, trial_run) trial_run() - self._agent = LibSecretAgent(schema_name, attributes, **kwargs) + self._agent = LibSecretAgent( + schema_name or _auto_hash(signal_location), attributes or {}, **kwargs) self._file_persistence = FilePersistence(signal_location) # Favor composition def save(self, content): diff --git a/sample/persistence_sample.py b/sample/persistence_sample.py index 74074d3..f5c8c06 100644 --- a/sample/persistence_sample.py +++ b/sample/persistence_sample.py @@ -10,7 +10,7 @@ def build_persistence(location, fallback_to_plaintext=False): if sys.platform.startswith('win'): return FilePersistenceWithDataProtection(location) if sys.platform.startswith('darwin'): - return KeychainPersistence(location, "my_service_name", "my_account_name") + return KeychainPersistence(location) if sys.platform.startswith('linux'): try: return LibsecretPersistence( @@ -21,8 +21,6 @@ def build_persistence(location, fallback_to_plaintext=False): # unless there would frequently be a desktop session and # a remote ssh session being active simultaneously. location, - schema_name="my_schema_name", - attributes={"my_attr1": "foo", "my_attr2": "bar"}, ) except: # pylint: disable=bare-except if not fallback_to_plaintext: @@ -31,6 +29,7 @@ def build_persistence(location, fallback_to_plaintext=False): return FilePersistence(location) persistence = build_persistence("storage.bin", fallback_to_plaintext=False) +print("Type of persistence: {}".format(persistence.__class__.__name__)) print("Is this persistence encrypted?", persistence.is_encrypted) data = { # It can be anything, here we demonstrate an arbitrary json object diff --git a/sample/token_cache_sample.py b/sample/token_cache_sample.py index b48e19d..7210efa 100644 --- a/sample/token_cache_sample.py +++ b/sample/token_cache_sample.py @@ -10,7 +10,7 @@ def build_persistence(location, fallback_to_plaintext=False): if sys.platform.startswith('win'): return FilePersistenceWithDataProtection(location) if sys.platform.startswith('darwin'): - return KeychainPersistence(location, "my_service_name", "my_account_name") + return KeychainPersistence(location) if sys.platform.startswith('linux'): try: return LibsecretPersistence( @@ -21,8 +21,6 @@ def build_persistence(location, fallback_to_plaintext=False): # unless there would frequently be a desktop session and # a remote ssh session being active simultaneously. location, - schema_name="my_schema_name", - attributes={"my_attr1": "foo", "my_attr2": "bar"}, ) except: # pylint: disable=bare-except if not fallback_to_plaintext: @@ -31,6 +29,7 @@ def build_persistence(location, fallback_to_plaintext=False): return FilePersistence(location) persistence = build_persistence("token_cache.bin") +print("Type of persistence: {}".format(persistence.__class__.__name__)) print("Is this persistence encrypted?", persistence.is_encrypted) cache = PersistedTokenCache(persistence) diff --git a/tests/test_persistence.py b/tests/test_persistence.py index bbbe155..dfc0963 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -54,8 +54,7 @@ def test_nonexistent_file_persistence_with_data_protection(temp_location): not sys.platform.startswith('darwin'), reason="Requires OSX. Whether running on TRAVIS CI does not seem to matter.") def test_keychain_persistence(temp_location): - _test_persistence_roundtrip(KeychainPersistence( - temp_location, "my_service_name", "my_account_name")) + _test_persistence_roundtrip(KeychainPersistence(temp_location)) @pytest.mark.skipif( not sys.platform.startswith('darwin'), @@ -69,11 +68,7 @@ def test_nonexistent_keychain_persistence(temp_location): is_running_on_travis_ci or not sys.platform.startswith('linux'), reason="Requires Linux Desktop. Headless or SSH session won't work.") def test_libsecret_persistence(temp_location): - _test_persistence_roundtrip(LibsecretPersistence( - temp_location, - "my_schema_name", - {"my_attr_1": "foo", "my_attr_2": "bar"}, - )) + _test_persistence_roundtrip(LibsecretPersistence(temp_location)) @pytest.mark.skipif( is_running_on_travis_ci or not sys.platform.startswith('linux'),