diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 4afbbad24..60e238da6 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -117,6 +117,7 @@ def emit(self, record: logging.LogRecord) -> None: REDIRECT_URL = "http://localhost:8020" CLIENT_ID = "dbt-databricks" SCOPES = ["all-apis", "offline_access"] +MAX_NT_PASSWORD_SIZE = 1280 # toggle for session managements that minimizes the number of sessions opened/closed USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE" @@ -385,7 +386,7 @@ def authenticate(self, in_provider: Optional[TCredentialProvider]) -> TCredentia # optional branch. Try and keep going if it does not work try: # try to get cached credentials - credsdict = keyring.get_password("dbt-databricks", host) + credsdict = self.get_sharded_password("dbt-databricks", host) if credsdict: provider = SessionCredentials.from_dict(oauth_client, json.loads(credsdict)) @@ -396,7 +397,7 @@ def authenticate(self, in_provider: Optional[TCredentialProvider]) -> TCredentia except Exception as e: logger.debug(e) # whatever it is, get rid of the cache - keyring.delete_password("dbt-databricks", host) + self.delete_sharded_password("dbt-databricks", host) # error with keyring. Maybe machine has no password persistency except Exception as e: @@ -410,7 +411,9 @@ def authenticate(self, in_provider: Optional[TCredentialProvider]) -> TCredentia # save for later self._credentials_provider = provider.as_dict() try: - keyring.set_password("dbt-databricks", host, json.dumps(self._credentials_provider)) + self.set_sharded_password( + "dbt-databricks", host, json.dumps(self._credentials_provider) + ) # error with keyring. Maybe machine has no password persistency except Exception as e: logger.debug(e) @@ -421,6 +424,63 @@ def authenticate(self, in_provider: Optional[TCredentialProvider]) -> TCredentia finally: self._lock.release() + def set_sharded_password(self, service_name: str, username: str, password: str) -> None: + max_size = MAX_NT_PASSWORD_SIZE + + # if not Windows or "small" password, stick to the default + if os.name != "nt" or len(password) < max_size: + keyring.set_password(service_name, username, password) + else: + logger.debug(f"password is {len(password)} characters, sharding it.") + + password_shards = [ + password[i : i + max_size] for i in range(0, len(password), max_size) + ] + shard_info = { + "sharded_password": True, + "shard_count": len(password_shards), + } + + # store the "shard info" as the "base" password + keyring.set_password(service_name, username, json.dumps(shard_info)) + # then store all shards with the shard number as postfix + for i, s in enumerate(password_shards): + keyring.set_password(service_name, f"{username}__{i}", s) + + def get_sharded_password(self, service_name: str, username: str) -> Optional[str]: + password = keyring.get_password(service_name, username) + + # check for "shard info" stored as json + try: + password_as_dict = json.loads(str(password)) + if password_as_dict.get("sharded_password"): + # if password was stored shared, reconstruct it + shard_count = int(password_as_dict.get("shard_count")) + + password = "" + for i in range(shard_count): + password += str(keyring.get_password(service_name, f"{username}__{i}")) + except ValueError: + pass + + return password + + def delete_sharded_password(self, service_name: str, username: str) -> None: + password = keyring.get_password(service_name, username) + + # check for "shard info" stored as json. If so delete all shards + try: + password_as_dict = json.loads(str(password)) + if password_as_dict.get("sharded_password"): + shard_count = int(password_as_dict.get("shard_count")) + for i in range(shard_count): + keyring.delete_password(service_name, f"{username}__{i}") + except ValueError: + pass + + # delete "base" password + keyring.delete_password(service_name, username) + def _provider_from_dict(self) -> Optional[TCredentialProvider]: if self.token: return token_auth.from_dict(self._credentials_provider) diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index c887b547c..06daf071d 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -1,5 +1,6 @@ import unittest from dbt.adapters.databricks.connections import DatabricksCredentials +import keyring.backend import pytest @@ -71,3 +72,94 @@ def test_token(self): headers_fn2 = provider_b() headers2 = headers_fn2() self.assertEqual(headers, headers2) + + +class TestShardedPassword(unittest.TestCase): + def test_store_and_delete_short_password(self): + # set the keyring to mock class + keyring.set_keyring(MockKeyring()) + + service = "dbt-databricks" + host = "my.cloud.databricks.com" + long_password = "x" * 10 + + creds = DatabricksCredentials( + host=host, token="foo", database="andre", http_path="http://foo", schema="dbt" + ) + creds.set_sharded_password(service, host, long_password) + + retrieved_password = creds.get_sharded_password(service, host) + self.assertEqual(long_password, retrieved_password) + + # delete password + creds.delete_sharded_password(service, host) + retrieved_password = creds.get_sharded_password(service, host) + self.assertIsNone(retrieved_password) + + def test_store_and_delete_long_password(self): + # set the keyring to mock class + keyring.set_keyring(MockKeyring()) + + service = "dbt-databricks" + host = "my.cloud.databricks.com" + long_password = "x" * 3000 + + creds = DatabricksCredentials( + host=host, token="foo", database="andre", http_path="http://foo", schema="dbt" + ) + creds.set_sharded_password(service, host, long_password) + + retrieved_password = creds.get_sharded_password(service, host) + self.assertEqual(long_password, retrieved_password) + + # delete password + creds.delete_sharded_password(service, host) + retrieved_password = creds.get_sharded_password(service, host) + self.assertIsNone(retrieved_password) + + +class MockKeyring(keyring.backend.KeyringBackend): + def __init__(self): + self.file_location = self._generate_test_root_dir() + + def priority(self): + return 1 + + def _generate_test_root_dir(self): + import tempfile + + return tempfile.mkdtemp(prefix="dbt-unit-test-") + + def file_path(self, servicename, username): + from os.path import join + + file_location = self.file_location + file_name = f"{servicename}_{username}.txt" + return join(file_location, file_name) + + def set_password(self, servicename, username, password): + file_path = self.file_path(servicename, username) + + with open(file_path, "w") as file: + file.write(password) + + def get_password(self, servicename, username): + import os + + file_path = self.file_path(servicename, username) + if not os.path.exists(file_path): + return None + + with open(file_path, "r") as file: + password = file.read() + + return password + + def delete_password(self, servicename, username): + import os + + file_path = self.file_path(servicename, username) + if not os.path.exists(file_path): + return None + + os.remove(file_path)