Merge pull request #116 from fishtown-analytics/odbc-driver-support
Support for ODBC driver connection type
Kyle Wigley authored Nov 6, 2020
2 parents e7d73ef + f882e15 commit 1bbe718
Showing 13 changed files with 424 additions and 49 deletions.
34 changes: 28 additions & 6 deletions .circleci/config.yml
@@ -5,7 +5,7 @@ jobs:
environment:
DBT_INVOCATION_ENV: circle
docker:
- image: fishtownanalytics/test-container:9
- image: fishtownanalytics/test-container:10
steps:
- checkout
- run: tox -e flake8,unit
@@ -14,7 +14,7 @@ jobs:
environment:
DBT_INVOCATION_ENV: circle
docker:
- image: fishtownanalytics/test-container:9
- image: fishtownanalytics/test-container:10
- image: godatadriven/spark:2
environment:
WAIT_FOR: localhost:5432
@@ -46,16 +46,35 @@ jobs:
- store_artifacts:
path: ./logs

integration-spark-databricks:
integration-spark-databricks-http:
environment:
DBT_INVOCATION_ENV: circle
docker:
- image: fishtownanalytics/test-container:9
- image: fishtownanalytics/test-container:10
steps:
- checkout
- run:
name: Run integration tests
command: tox -e integration-spark-databricks
command: tox -e integration-spark-databricks-http
no_output_timeout: 1h
- store_artifacts:
path: ./logs

integration-spark-databricks-odbc:
environment:
DBT_INVOCATION_ENV: circle
ODBC_DRIVER: Simba # TODO: move env var to Docker image
docker:
# image based on `fishtownanalytics/test-container` w/ Simba ODBC Spark driver installed
- image: 828731156495.dkr.ecr.us-east-1.amazonaws.com/dbt-spark-odbc-test-container:latest
aws_auth:
aws_access_key_id: $AWS_ACCESS_KEY_ID_STAGING
aws_secret_access_key: $AWS_SECRET_ACCESS_KEY_STAGING
steps:
- checkout
- run:
name: Run integration tests
command: tox -e integration-spark-databricks-odbc-cluster,integration-spark-databricks-odbc-sql-endpoint
no_output_timeout: 1h
- store_artifacts:
path: ./logs
@@ -68,6 +87,9 @@ workflows:
- integration-spark-thrift:
requires:
- unit
- integration-spark-databricks:
- integration-spark-databricks-http:
requires:
- unit
- integration-spark-databricks-odbc:
requires:
- unit
1 change: 1 addition & 0 deletions .gitignore
@@ -11,3 +11,4 @@ dist/
dbt-integration-tests
test/integration/.user.yml
.DS_Store
.vscode
78 changes: 56 additions & 22 deletions README.md
@@ -27,11 +27,21 @@ For more information on using Spark with dbt, consult the dbt documentation:
### Installation
This plugin can be installed via pip:

```
```bash
# Install dbt-spark from PyPi:
$ pip install dbt-spark
```

dbt-spark also supports connections via an ODBC driver, which requires [`pyodbc`](https://github.com/mkleehammer/pyodbc). You can install it separately, or pull it in via pip as an extra:

```bash
# Install dbt-spark w/ pyodbc from PyPi:
$ pip install "dbt-spark[ODBC]"
```

See https://github.com/mkleehammer/pyodbc/wiki/Install for more info about installing `pyodbc`.
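
For context, `connections.py` (further down in this diff) treats `pyodbc` as an optional dependency and only complains when the `odbc` method is actually requested. A minimal sketch of that pattern (the `require_pyodbc` helper is hypothetical and stands in for the check done in `SparkCredentials.__post_init__`):

```python
# Import pyodbc lazily: it is only needed for the `odbc` connection method.
try:
    import pyodbc
except ImportError:
    pyodbc = None


def require_pyodbc(method: str) -> None:
    """Fail with a helpful hint if `odbc` is requested without pyodbc installed."""
    if method == "odbc" and pyodbc is None:
        raise RuntimeError(
            "The odbc connection method requires additional dependencies. "
            "Install them with `pip install dbt-spark[ODBC]`."
        )
```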


### Configuring your profile

**Connection Method**
@@ -40,18 +50,20 @@ Connections can be made to Spark in two different modes. The `http` mode is used

A dbt profile can be configured to run against Spark using the following configuration:

| Option | Description | Required? | Example |
|---------|----------------------------------------------------|-------------------------|--------------------------|
| method | Specify the connection method (`thrift` or `http`) | Required | `http` |
| schema | Specify the schema (database) to build models into | Required | `analytics` |
| host | The hostname to connect to | Required | `yourorg.sparkhost.com` |
| port | The port to connect to the host on | Optional (default: 443 for `http`, 10001 for `thrift`) | `443` |
| token | The token to use for authenticating to the cluster | Required for `http` | `abc123` |
| organization | The id of the Azure Databricks workspace being used; only for Azure Databricks | See Databricks Note | `1234567891234567` |
| cluster | The name of the cluster to connect to | Required for `http` | `01234-23423-coffeetime` |
| user | The username to use to connect to the cluster | Optional | `hadoop` |
| connect_timeout | The number of seconds to wait before retrying to connect to a Pending Spark cluster | Optional (default: 10) | `60` |
| connect_retries | The number of times to try connecting to a Pending Spark cluster before giving up | Optional (default: 0) | `5` |
| Option | Description | Required? | Example |
| --------------- | ----------------------------------------------------------------------------------- | ------------------------------------------------------------------ | ---------------------------------------------- |
| method | Specify the connection method (`thrift` or `http` or `odbc`) | Required | `http` |
| schema | Specify the schema (database) to build models into | Required | `analytics` |
| host | The hostname to connect to | Required | `yourorg.sparkhost.com` |
| port | The port to connect to the host on | Optional (default: 443 for `http` and `odbc`, 10001 for `thrift`) | `443` |
| token | The token to use for authenticating to the cluster | Required for `http` and `odbc` | `abc123` |
| organization | The id of the Azure Databricks workspace being used; only for Azure Databricks | See Databricks Note | `1234567891234567` |
| cluster | The name of the cluster to connect to | Required for `http` and `odbc` if connecting to a specific cluster | `01234-23423-coffeetime` |
| endpoint | The ID of the SQL endpoint to connect to | Required for `odbc` if connecting to SQL endpoint | `1234567891234a` |
| driver | Path of ODBC driver installed or name of ODBC DSN configured | Required for `odbc` | `/opt/simba/spark/lib/64/libsparkodbc_sb64.so` |
| user | The username to use to connect to the cluster | Optional | `hadoop` |
| connect_timeout | The number of seconds to wait before retrying to connect to a Pending Spark cluster | Optional (default: 10) | `60` |
| connect_retries | The number of times to try connecting to a Pending Spark cluster before giving up | Optional (default: 0) | `5` |

**Databricks Note**

@@ -104,6 +116,28 @@ your_profile_name:
connect_timeout: 60
```

**ODBC connection**
```
your_profile_name:
target: dev
outputs:
dev:
method: odbc
type: spark
schema: analytics
host: yourorg.sparkhost.com
organization: 1234567891234567 # Azure Databricks ONLY
port: 443
token: abc123
# one of:
cluster: 01234-23423-coffeetime
endpoint: coffee01234time
driver: path/to/driver
connect_retries: 5 # cluster only
connect_timeout: 60 # cluster only
```
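
For illustration only, here is a Python sketch of how the adapter (see `connections.py` in this commit) turns these profile fields into an ODBC connection string. The `build_connection_string` function is hypothetical, but the HTTP paths and driver options mirror the ones used in the code:

```python
# Hypothetical helper mirroring SparkConnectionManager's ODBC setup.
def build_connection_string(driver, host, port, token, organization=None,
                            cluster=None, endpoint=None):
    if cluster is not None:
        # general-purpose cluster
        http_path = f"/sql/protocolv1/o/{organization}/{cluster}"
    elif endpoint is not None:
        # SQL endpoint
        http_path = f"/sql/1.0/endpoints/{endpoint}"
    else:
        raise ValueError("Either `cluster` or `endpoint` must be set")

    options = {
        "DRIVER": driver,        # driver path or DSN name
        "HOST": host,
        "PORT": port,
        "UID": "token",          # token-based auth uses the literal user "token"
        "PWD": token,
        "HTTPPath": http_path,
        "AuthMech": 3,           # username/password
        "SparkServerType": 3,
        "ThriftTransport": 2,    # HTTP transport
        "SSL": 1,
    }
    return ";".join(f"{k}={v}" for k, v in options.items())


print(build_connection_string(
    driver="/opt/simba/spark/lib/64/libsparkodbc_sb64.so",
    host="yourorg.sparkhost.com",
    port=443,
    token="abc123",
    organization="1234567891234567",
    cluster="01234-23423-coffeetime",
))
```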


### Usage Notes
@@ -113,15 +147,15 @@ your_profile_name:
The following configurations can be supplied to models run with the dbt-spark plugin:


| Option | Description | Required? | Example |
|---------|----------------------------------------------------|-------------------------|--------------------------|
| file_format | The file format to use when creating tables (`parquet`, `delta`, `csv`, `json`, `text`, `jdbc`, `orc`, `hive` or `libsvm`). | Optional | `parquet`|
| location_root | The created table uses the specified directory to store its data. The table alias is appended to it. | Optional | `/mnt/root` |
| partition_by | Partition the created table by the specified columns. A directory is created for each partition. | Optional | `partition_1` |
| clustered_by | Each partition in the created table will be split into a fixed number of buckets by the specified columns. | Optional | `cluster_1` |
| buckets | The number of buckets to create while clustering | Required if `clustered_by` is specified | `8` |
| incremental_strategy | The strategy to use for incremental models (`insert_overwrite` or `merge`). Note `merge` requires `file_format` = `delta` and `unique_key` to be specified. | Optional (default: `insert_overwrite`) | `merge` |
| persist_docs | Whether dbt should include the model description as a table `comment` | Optional | `{'relation': true}` |
| Option | Description | Required? | Example |
| -------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------- | -------------------- |
| file_format | The file format to use when creating tables (`parquet`, `delta`, `csv`, `json`, `text`, `jdbc`, `orc`, `hive` or `libsvm`). | Optional | `parquet` |
| location_root | The created table uses the specified directory to store its data. The table alias is appended to it. | Optional | `/mnt/root` |
| partition_by | Partition the created table by the specified columns. A directory is created for each partition. | Optional | `partition_1` |
| clustered_by | Each partition in the created table will be split into a fixed number of buckets by the specified columns. | Optional | `cluster_1` |
| buckets | The number of buckets to create while clustering | Required if `clustered_by` is specified | `8` |
| incremental_strategy | The strategy to use for incremental models (`insert_overwrite` or `merge`). Note `merge` requires `file_format` = `delta` and `unique_key` to be specified. | Optional (default: `insert_overwrite`) | `merge` |
| persist_docs | Whether dbt should include the model description as a table `comment` | Optional | `{'relation': true}` |


**Incremental Models**
105 changes: 99 additions & 6 deletions dbt/adapters/spark/connections.py
@@ -6,11 +6,17 @@
from dbt.contracts.connection import ConnectionState
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.utils import DECIMALS
from dbt.adapters.spark import __version__

from TCLIService.ttypes import TOperationState as ThriftState
from thrift.transport import THttpClient
from pyhive import hive
try:
import pyodbc
except ImportError:
pyodbc = None
from datetime import datetime
import sqlparams

from hologram.helpers import StrEnum
from dataclasses import dataclass
@@ -22,9 +28,14 @@
NUMBERS = DECIMALS + (int, float)


def _build_odbc_connnection_string(**kwargs) -> str:
return ";".join([f"{k}={v}" for k, v in kwargs.items()])


class SparkConnectionMethod(StrEnum):
THRIFT = 'thrift'
HTTP = 'http'
ODBC = 'odbc'


@dataclass
Expand All @@ -33,7 +44,9 @@ class SparkCredentials(Credentials):
method: SparkConnectionMethod
schema: str
database: Optional[str]
driver: Optional[str] = None
cluster: Optional[str] = None
endpoint: Optional[str] = None
token: Optional[str] = None
user: Optional[str] = None
port: int = 443
@@ -57,15 +70,34 @@ def __post_init__(self):
)
self.database = None

if self.method == SparkConnectionMethod.ODBC and pyodbc is None:
raise dbt.exceptions.RuntimeException(
f"{self.method} connection method requires "
"additional dependencies. \n"
"Install the additional required dependencies with "
"`pip install dbt-spark[ODBC]`"
)

if (
self.method == SparkConnectionMethod.ODBC and
self.cluster and
self.endpoint
):
raise dbt.exceptions.RuntimeException(
"`cluster` and `endpoint` cannot both be set when"
f" using {self.method} method to connect to Spark"
)

@property
def type(self):
return 'spark'

def _connection_keys(self):
return 'host', 'port', 'cluster', 'schema', 'organization'
return ('host', 'port', 'cluster',
'endpoint', 'schema', 'organization')


class ConnectionWrapper(object):
class PyhiveConnectionWrapper(object):
"""Wrap a Spark connection in a way that no-ops transactions"""
# https://forums.databricks.com/questions/2157/in-apache-spark-sql-can-we-roll-back-the-transacti.html # noqa

@@ -177,11 +209,28 @@ def description(self):
return self._cursor.description


class PyodbcConnectionWrapper(PyhiveConnectionWrapper):

def execute(self, sql, bindings=None):
if sql.strip().endswith(";"):
sql = sql.strip()[:-1]
# pyodbc does not handle a None type binding!
if bindings is None:
self._cursor.execute(sql)
else:
# pyodbc only supports `qmark` sql params!
query = sqlparams.SQLParams('format', 'qmark')
sql, bindings = query.format(sql, bindings)
self._cursor.execute(sql, *bindings)


class SparkConnectionManager(SQLConnectionManager):
TYPE = 'spark'

SPARK_CLUSTER_HTTP_PATH = "/sql/protocolv1/o/{organization}/{cluster}"
SPARK_SQL_ENDPOINT_HTTP_PATH = "/sql/1.0/endpoints/{endpoint}"
SPARK_CONNECTION_URL = (
"https://{host}:{port}/sql/protocolv1/o/{organization}/{cluster}"
"https://{host}:{port}" + SPARK_CLUSTER_HTTP_PATH
)

@contextmanager
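
A brief aside on `PyodbcConnectionWrapper.execute` above: `pyodbc` only accepts `qmark` (`?`) placeholders, so `sqlparams` rewrites the `format` (`%s`) style that dbt produces. A standalone sketch of that translation, with a made-up query and bindings:

```python
import sqlparams

# Convert `format`-style (%s) placeholders into qmark (?) placeholders,
# the only parameter style pyodbc understands.
converter = sqlparams.SQLParams("format", "qmark")

sql = "select * from events where status = %s and day = %s"  # made-up query
bindings = ["shipped", "2020-11-06"]                          # made-up values

qmark_sql, qmark_bindings = converter.format(sql, bindings)
print(qmark_sql)       # select * from events where status = ? and day = ?
print(qmark_bindings)  # bindings in the order expected by the qmark query
```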
@@ -243,7 +292,7 @@ def open(cls, connection):

for i in range(1 + creds.connect_retries):
try:
if creds.method == 'http':
if creds.method == SparkConnectionMethod.HTTP:
cls.validate_creds(creds, ['token', 'host', 'port',
'cluster', 'organization'])

@@ -265,7 +314,8 @@ def open(cls, connection):
})

conn = hive.connect(thrift_transport=transport)
elif creds.method == 'thrift':
handle = PyhiveConnectionWrapper(conn)
elif creds.method == SparkConnectionMethod.THRIFT:
cls.validate_creds(creds,
['host', 'port', 'user', 'schema'])

@@ -274,6 +324,50 @@ def open(cls, connection):
username=creds.user,
auth=creds.auth,
kerberos_service_name=creds.kerberos_service_name) # noqa
handle = PyhiveConnectionWrapper(conn)
elif creds.method == SparkConnectionMethod.ODBC:
http_path = None
if creds.cluster is not None:
required_fields = ['driver', 'host', 'port', 'token',
'organization', 'cluster']
http_path = cls.SPARK_CLUSTER_HTTP_PATH.format(
organization=creds.organization,
cluster=creds.cluster
)
elif creds.endpoint is not None:
required_fields = ['driver', 'host', 'port', 'token',
'endpoint']
http_path = cls.SPARK_SQL_ENDPOINT_HTTP_PATH.format(
endpoint=creds.endpoint
)
else:
raise dbt.exceptions.DbtProfileError(
"Either `cluster` or `endpoint` must set when"
" using the odbc method to connect to Spark"
)

cls.validate_creds(creds, required_fields)

dbt_spark_version = __version__.version
user_agent_entry = f"fishtown-analytics-dbt-spark/{dbt_spark_version} (Databricks)" # noqa

# https://www.simba.com/products/Spark/doc/v2/ODBC_InstallGuide/unix/content/odbc/options/driver.htm
connection_str = _build_odbc_connnection_string(
DRIVER=creds.driver,
HOST=creds.host,
PORT=creds.port,
UID="token",
PWD=creds.token,
HTTPPath=http_path,
AuthMech=3,
SparkServerType=3,
ThriftTransport=2,
SSL=1,
UserAgentEntry=user_agent_entry,
)

conn = pyodbc.connect(connection_str, autocommit=True)
handle = PyodbcConnectionWrapper(conn)
else:
raise dbt.exceptions.DbtProfileError(
f"invalid credential method: {creds.method}"
Expand Down Expand Up @@ -304,7 +398,6 @@ def open(cls, connection):
else:
raise exc

handle = ConnectionWrapper(conn)
connection.handle = handle
connection.state = ConnectionState.OPEN
return connection
2 changes: 1 addition & 1 deletion dev_requirements.txt
@@ -10,6 +10,6 @@ pytest-xdist>=2.1.0,<3
flaky>=3.5.3,<4

# Test requirements
pytest-dbt-adapter==0.2.0
pytest-dbt-adapter==0.3.0
sasl==0.2.1
thrift_sasl==0.4.1
2 changes: 2 additions & 0 deletions requirements.txt
@@ -1,3 +1,5 @@
dbt-core==0.18.0
PyHive[hive]>=0.6.0,<0.7.0
pyodbc>=4.0.30
sqlparams>=3.0.0
thrift>=0.11.0,<0.12.0
8 changes: 6 additions & 2 deletions setup.py
@@ -62,6 +62,10 @@ def _dbt_spark_version():
install_requires=[
f'dbt-core=={dbt_version}',
'PyHive[hive]>=0.6.0,<0.7.0',
'thrift>=0.11.0,<0.12.0',
]
'sqlparams>=3.0.0',
'thrift>=0.11.0,<0.12.0'
],
extras_require={
"ODBC": ['pyodbc>=4.0.30'],
}
)