Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ To run all of these examples you can clone the entire repository to your disk. O
this example the string `ExamplePartnerTag` will be added to the the user agent on every request.
- **`staging_ingestion.py`** shows how the connector handles Databricks' experimental staging ingestion commands `GET`, `PUT`, and `REMOVE`.
- **`sqlalchemy.py`** shows a basic example of connecting to Databricks with [SQLAlchemy](https://www.sqlalchemy.org/).
- **`custom_cred_provider.py`** shows how to pass a custom credential provider to bypass connector authentication. Please install databricks-sdk prior to running this example.
29 changes: 29 additions & 0 deletions examples/custom_cred_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# please install databricks-sdk prior to running this example.

from databricks import sql
from databricks.sdk.oauth import OAuthClient
import os

oauth_client = OAuthClient(host=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
client_id=os.getenv("DATABRICKS_CLIENT_ID"),
client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"),
redirect_url=os.getenv("APP_REDIRECT_URL"),
scopes=['all-apis', 'offline_access'])

consent = oauth_client.initiate_consent()

creds = consent.launch_external_browser()

with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
credentials_provider=creds) as connection:

for x in range(1, 5):
cursor = connection.cursor()
cursor.execute('SELECT 1+1')
result = cursor.fetchall()
for row in result:
print(row)
cursor.close()

connection.close()
6 changes: 6 additions & 0 deletions src/databricks/sql/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
AuthProvider,
AccessTokenAuthProvider,
BasicAuthProvider,
ExternalAuthProvider,
DatabricksOAuthProvider,
)
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
Expand All @@ -30,6 +31,7 @@ def __init__(
use_cert_as_auth: str = None,
tls_client_cert_file: str = None,
oauth_persistence=None,
credentials_provider=None,
):
self.hostname = hostname
self.username = username
Expand All @@ -42,9 +44,12 @@ def __init__(
self.use_cert_as_auth = use_cert_as_auth
self.tls_client_cert_file = tls_client_cert_file
self.oauth_persistence = oauth_persistence
self.credentials_provider = credentials_provider


def get_auth_provider(cfg: ClientContext):
if cfg.credentials_provider:
return ExternalAuthProvider(cfg.credentials_provider)
if cfg.auth_type == AuthType.DATABRICKS_OAUTH.value:
assert cfg.oauth_redirect_port_range is not None
assert cfg.oauth_client_id is not None
Expand Down Expand Up @@ -94,5 +99,6 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
else PYSQL_OAUTH_REDIRECT_PORT_RANGE,
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
credentials_provider=kwargs.get("credentials_provider"),
)
return get_auth_provider(cfg)
29 changes: 28 additions & 1 deletion src/databricks/sql/auth/authenticators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import base64
import logging
from typing import Dict, List
from typing import Callable, Dict, List

from databricks.sql.auth.oauth import OAuthManager

Expand All @@ -14,6 +15,22 @@ def add_headers(self, request_headers: Dict[str, str]):
pass


HeaderFactory = Callable[[], Dict[str, str]]

# In order to keep compatibility with SDK
class CredentialsProvider(abc.ABC):
"""CredentialsProvider is the protocol (call-side interface)
for authenticating requests to Databricks REST APIs"""

@abc.abstractmethod
def auth_type(self) -> str:
...

@abc.abstractmethod
def __call__(self, *args, **kwargs) -> HeaderFactory:
...


# Private API: this is an evolving interface and it will change in the future.
# Please must not depend on it in your applications.
class AccessTokenAuthProvider(AuthProvider):
Expand Down Expand Up @@ -120,3 +137,13 @@ def _update_token_if_expired(self):
except Exception as e:
logging.error(f"unexpected error in oauth token update", e, exc_info=True)
raise e


class ExternalAuthProvider(AuthProvider):
def __init__(self, credentials_provider: CredentialsProvider) -> None:
self._header_factory = credentials_provider()

def add_headers(self, request_headers: Dict[str, str]):
headers = self._header_factory()
for k, v in headers.items():
request_headers[k] = v
37 changes: 36 additions & 1 deletion tests/unit/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import unittest

from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, AuthProvider
from databricks.sql.auth.auth import AccessTokenAuthProvider, BasicAuthProvider, AuthProvider, ExternalAuthProvider
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory


class Auth(unittest.TestCase):
Expand Down Expand Up @@ -37,6 +38,22 @@ def test_noop_auth_provider(self):
self.assertEqual(len(http_request.keys()), 1)
self.assertEqual(http_request['myKey'], 'myVal')

def test_external_provider(self):
class MyProvider(CredentialsProvider):
def auth_type(self) -> str:
return "mine"

def __call__(self, *args, **kwargs) -> HeaderFactory:
return lambda: {"foo": "bar"}

auth = ExternalAuthProvider(MyProvider())

http_request = {'myKey': 'myVal'}
auth.add_headers(http_request)
self.assertEqual(http_request['foo'], 'bar')
self.assertEqual(len(http_request.keys()), 2)
self.assertEqual(http_request['myKey'], 'myVal')

def test_get_python_sql_connector_auth_provider_access_token(self):
hostname = "moderakh-test.cloud.databricks.com"
kwargs = {'access_token': 'dpi123'}
Expand All @@ -47,6 +64,24 @@ def test_get_python_sql_connector_auth_provider_access_token(self):
auth_provider.add_headers(headers)
self.assertEqual(headers['Authorization'], 'Bearer dpi123')

def test_get_python_sql_connector_auth_provider_external(self):

class MyProvider(CredentialsProvider):
def auth_type(self) -> str:
return "mine"

def __call__(self, *args, **kwargs) -> HeaderFactory:
return lambda: {"foo": "bar"}

hostname = "moderakh-test.cloud.databricks.com"
kwargs = {'credentials_provider': MyProvider()}
auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs)
self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider")

headers = {}
auth_provider.add_headers(headers)
self.assertEqual(headers['foo'], 'bar')

def test_get_python_sql_connector_auth_provider_username_password(self):
username = "moderakh"
password = "Elevate Databricks 123!!!"
Expand Down