Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 63 additions & 3 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def emit(self, record: logging.LogRecord) -> None:
REDIRECT_URL = "http://localhost:8020"
CLIENT_ID = "dbt-databricks"
SCOPES = ["all-apis", "offline_access"]
MAX_NT_PASSWORD_SIZE = 1280

# toggle for session managements that minimizes the number of sessions opened/closed
USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE"
Expand Down Expand Up @@ -385,7 +386,7 @@ def authenticate(self, in_provider: Optional[TCredentialProvider]) -> TCredentia
# optional branch. Try and keep going if it does not work
try:
# try to get cached credentials
credsdict = keyring.get_password("dbt-databricks", host)
credsdict = self.get_sharded_password("dbt-databricks", host)

if credsdict:
provider = SessionCredentials.from_dict(oauth_client, json.loads(credsdict))
Expand All @@ -396,7 +397,7 @@ def authenticate(self, in_provider: Optional[TCredentialProvider]) -> TCredentia
except Exception as e:
logger.debug(e)
# whatever it is, get rid of the cache
keyring.delete_password("dbt-databricks", host)
self.delete_sharded_password("dbt-databricks", host)

# error with keyring. Maybe machine has no password persistency
except Exception as e:
Expand All @@ -410,7 +411,9 @@ def authenticate(self, in_provider: Optional[TCredentialProvider]) -> TCredentia
# save for later
self._credentials_provider = provider.as_dict()
try:
keyring.set_password("dbt-databricks", host, json.dumps(self._credentials_provider))
self.set_sharded_password(
"dbt-databricks", host, json.dumps(self._credentials_provider)
)
# error with keyring. Maybe machine has no password persistency
except Exception as e:
logger.debug(e)
Expand All @@ -421,6 +424,63 @@ def authenticate(self, in_provider: Optional[TCredentialProvider]) -> TCredentia
finally:
self._lock.release()

def set_sharded_password(self, service_name: str, username: str, password: str) -> None:
max_size = MAX_NT_PASSWORD_SIZE

# if not Windows or "small" password, stick to the default
if os.name != "nt" or len(password) < max_size:
keyring.set_password(service_name, username, password)
else:
logger.debug(f"password is {len(password)} characters, sharding it.")

password_shards = [
password[i : i + max_size] for i in range(0, len(password), max_size)
]
shard_info = {
"sharded_password": True,
"shard_count": len(password_shards),
}

# store the "shard info" as the "base" password
keyring.set_password(service_name, username, json.dumps(shard_info))
# then store all shards with the shard number as postfix
for i, s in enumerate(password_shards):
keyring.set_password(service_name, f"{username}__{i}", s)

def get_sharded_password(self, service_name: str, username: str) -> Optional[str]:
password = keyring.get_password(service_name, username)

# check for "shard info" stored as json
try:
password_as_dict = json.loads(str(password))
if password_as_dict.get("sharded_password"):
# if password was stored shared, reconstruct it
shard_count = int(password_as_dict.get("shard_count"))

password = ""
for i in range(shard_count):
password += str(keyring.get_password(service_name, f"{username}__{i}"))
except ValueError:
pass

return password

def delete_sharded_password(self, service_name: str, username: str) -> None:
password = keyring.get_password(service_name, username)

# check for "shard info" stored as json. If so delete all shards
try:
password_as_dict = json.loads(str(password))
if password_as_dict.get("sharded_password"):
shard_count = int(password_as_dict.get("shard_count"))
for i in range(shard_count):
keyring.delete_password(service_name, f"{username}__{i}")
except ValueError:
pass

# delete "base" password
keyring.delete_password(service_name, username)

def _provider_from_dict(self) -> Optional[TCredentialProvider]:
if self.token:
return token_auth.from_dict(self._credentials_provider)
Expand Down
92 changes: 92 additions & 0 deletions tests/unit/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
from dbt.adapters.databricks.connections import DatabricksCredentials
import keyring.backend
import pytest


Expand Down Expand Up @@ -71,3 +72,94 @@ def test_token(self):
headers_fn2 = provider_b()
headers2 = headers_fn2()
self.assertEqual(headers, headers2)


class TestShardedPassword(unittest.TestCase):
def test_store_and_delete_short_password(self):
# set the keyring to mock class
keyring.set_keyring(MockKeyring())

service = "dbt-databricks"
host = "my.cloud.databricks.com"
long_password = "x" * 10

creds = DatabricksCredentials(
host=host, token="foo", database="andre", http_path="http://foo", schema="dbt"
)
creds.set_sharded_password(service, host, long_password)

retrieved_password = creds.get_sharded_password(service, host)
self.assertEqual(long_password, retrieved_password)

# delete password
creds.delete_sharded_password(service, host)
retrieved_password = creds.get_sharded_password(service, host)
self.assertIsNone(retrieved_password)

def test_store_and_delete_long_password(self):
# set the keyring to mock class
keyring.set_keyring(MockKeyring())

service = "dbt-databricks"
host = "my.cloud.databricks.com"
long_password = "x" * 3000

creds = DatabricksCredentials(
host=host, token="foo", database="andre", http_path="http://foo", schema="dbt"
)
creds.set_sharded_password(service, host, long_password)

retrieved_password = creds.get_sharded_password(service, host)
self.assertEqual(long_password, retrieved_password)

# delete password
creds.delete_sharded_password(service, host)
retrieved_password = creds.get_sharded_password(service, host)
self.assertIsNone(retrieved_password)


class MockKeyring(keyring.backend.KeyringBackend):
def __init__(self):
self.file_location = self._generate_test_root_dir()

def priority(self):
return 1

def _generate_test_root_dir(self):
import tempfile

return tempfile.mkdtemp(prefix="dbt-unit-test-")

def file_path(self, servicename, username):
from os.path import join

file_location = self.file_location
file_name = f"{servicename}_{username}.txt"
return join(file_location, file_name)

def set_password(self, servicename, username, password):
file_path = self.file_path(servicename, username)

with open(file_path, "w") as file:
file.write(password)

def get_password(self, servicename, username):
import os

file_path = self.file_path(servicename, username)
if not os.path.exists(file_path):
return None

with open(file_path, "r") as file:
password = file.read()

return password

def delete_password(self, servicename, username):
import os

file_path = self.file_path(servicename, username)
if not os.path.exists(file_path):
return None

os.remove(file_path)