Skip to content

Commit d426ee2

Browse files
whitepowdertomasfarias
authored andcommitted
add DbtTrinoHook and tests
1 parent 8be2afe commit d426ee2

File tree

2 files changed

+100
-0
lines changed

2 files changed

+100
-0
lines changed

airflow_dbt_python/hooks/target.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,3 +492,38 @@ class DbtSparkHook(DbtConnectionHook):
492492
),
493493
]
494494
conn_extra_params = []
495+
496+
497+
class DbtTrinoHook(DbtConnectionHook):
498+
"""A hook to interact with dbt using a Trino connection."""
499+
500+
conn_type = "trino"
501+
hook_name = "dbt Trino Hook"
502+
airflow_conn_types = (conn_type,)
503+
504+
conn_params = [
505+
"host",
506+
DbtConnectionParam("schema", default="public"),
507+
DbtConnectionParam("login", "user"),
508+
"password",
509+
DbtConnectionParam("port", default=443),
510+
]
511+
512+
conn_extra_params = [
513+
DbtConnectionParam("type", default="trino"),
514+
DbtConnectionParam("method", default="none"),
515+
DbtConnectionParam("http_scheme", default="https"),
516+
DbtConnectionParam("verify", default=True),
517+
DbtConnectionParam("database", "catalog", "hive"),
518+
"catalog",
519+
"host",
520+
"port",
521+
DbtConnectionParam("user", "user"),
522+
"password",
523+
"schema",
524+
"session_properties",
525+
"http_headers",
526+
"role",
527+
"client_tags",
528+
"source",
529+
]

tests/hooks/test_trino.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""Tests for the DbtTrinoHook."""
2+
3+
from airflow.models.connection import Connection
4+
5+
from airflow_dbt_python.hooks.target import DbtConnectionHook, DbtTrinoHook
6+
7+
8+
def test_trino_registration():
9+
"""Ensure that a Trino connection is mapped to a valid dbt profile dict."""
10+
conn = Connection(
11+
conn_id="my_trino",
12+
conn_type="trino",
13+
extra="""
14+
{
15+
"type": "trino",
16+
"method": "ldap",
17+
"user": "user",
18+
"password": "pass",
19+
"host": "trino-etl.example",
20+
"port": 443,
21+
"http_scheme": "https",
22+
"database": "ads_dwh_hdfs",
23+
"schema": "ads_platform_dwh",
24+
"verify": false
25+
}
26+
""",
27+
)
28+
29+
trino_hook = DbtTrinoHook(conn=conn)
30+
details = trino_hook.get_dbt_details_from_connection(conn)
31+
32+
assert details["type"] == "trino"
33+
assert details["method"] == "ldap"
34+
assert details["host"] == "trino-etl.example"
35+
assert details["port"] == 443
36+
assert details["user"] == "user"
37+
assert details["password"] == "pass"
38+
assert details["schema"] == "ads_platform_dwh"
39+
assert details["catalog"] == "ads_dwh_hdfs"
40+
assert details["verify"] is False
41+
42+
43+
def test_trino_catalog_overrides_database():
44+
"""Ensure that explicit catalog overrides database → catalog fallback."""
45+
conn = Connection(
46+
conn_id="my_trino_catalog",
47+
conn_type="trino",
48+
extra="""
49+
{
50+
"type": "trino",
51+
"user": "u",
52+
"password": "p",
53+
"host": "h",
54+
"port": 8443,
55+
"database": "fallback_db",
56+
"catalog": "explicit_catalog",
57+
"schema": "s"
58+
}
59+
""",
60+
)
61+
62+
trino_hook = DbtTrinoHook(conn=conn)
63+
details = trino_hook.get_dbt_details_from_connection(conn)
64+
65+
assert details["catalog"] == "explicit_catalog"

0 commit comments

Comments
 (0)