diff --git a/README.md b/README.md index cab2188a..b53e2b4f 100644 --- a/README.md +++ b/README.md @@ -100,9 +100,14 @@ Attributes can also be passed in the connection string. ```python from sqlalchemy import create_engine +from trino.sqlalchemy import URL engine = create_engine( - 'trino://user@localhost:8080/system', + URL( + host="localhost", + port=8080, + catalog="system" + ), connect_args={ "session_properties": {'query_max_run_time': '1d'}, "client_tags": ["tag1", "tag2"], @@ -119,6 +124,14 @@ engine = create_engine( '&experimental_python_types=true' '&roles={"catalog1": "role1"}' ) + +# or using the URL factory method +engine = create_engine(URL( + host="localhost", + port=8080, + client_tags=["tag1", "tag2"], + experimental_python_types=True +)) ``` ## Authentication mechanisms diff --git a/tests/unit/sqlalchemy/test_dialect.py b/tests/unit/sqlalchemy/test_dialect.py index 963b27f7..fde3a9f1 100644 --- a/tests/unit/sqlalchemy/test_dialect.py +++ b/tests/unit/sqlalchemy/test_dialect.py @@ -9,6 +9,7 @@ from trino.dbapi import Connection from trino.sqlalchemy.dialect import CertificateAuthentication, JWTAuthentication, TrinoDialect from trino.transaction import IsolationLevel +from trino.sqlalchemy import URL as trino_url class TestTrinoDialect: @@ -16,20 +17,35 @@ def setup(self): self.dialect = TrinoDialect() @pytest.mark.parametrize( - "url, expected_args, expected_kwargs", + "url, generated_url, expected_args, expected_kwargs", [ ( - make_url("trino://user@localhost"), + make_url(trino_url( + user="user", + host="localhost", + )), + 'trino://user@localhost:8080?source=trino-sqlalchemy', list(), - dict(host="localhost", catalog="system", user="user", source="trino-sqlalchemy"), + dict(host="localhost", catalog="system", user="user", port=8080, source="trino-sqlalchemy"), ), ( - make_url("trino://user@localhost:8080"), + make_url(trino_url( + user="user", + host="localhost", + port=443, + )), + 'trino://user@localhost:443?source=trino-sqlalchemy', list(), - dict(host="localhost", port=8080, catalog="system", user="user", source="trino-sqlalchemy"), + dict(host="localhost", port=443, catalog="system", user="user", source="trino-sqlalchemy"), ), ( - make_url("trino://user:pass@localhost:8080?source=trino-rulez"), + make_url(trino_url( + user="user", + password="pass", + host="localhost", + source="trino-rulez", + )), + 'trino://user:***@localhost:8080?source=trino-rulez', list(), dict( host="localhost", @@ -42,13 +58,64 @@ def setup(self): ), ), ( - make_url( - 'trino://user@localhost:8080?' - 'session_properties={"query_max_run_time": "1d"}' - '&http_headers={"trino": 1}' - '&extra_credential=[("a", "b"), ("c", "d")]' - '&client_tags=[1, "sql"]' - '&experimental_python_types=true'), + make_url(trino_url( + user="user", + host="localhost", + cert="/my/path/to/cert", + key="afdlsdfk%4#'", + )), + 'trino://user@localhost:8080' + '?cert=%2Fmy%2Fpath%2Fto%2Fcert' + '&key=afdlsdfk%254%23%27' + '&source=trino-sqlalchemy', + list(), + dict( + host="localhost", + port=8080, + catalog="system", + user="user", + auth=CertificateAuthentication("/my/path/to/cert", "afdlsdfk%4#'"), + http_scheme="https", + source="trino-sqlalchemy" + ), + ), + ( + make_url(trino_url( + user="user", + host="localhost", + access_token="afdlsdfk%4#'", + )), + 'trino://user@localhost:8080' + '?access_token=afdlsdfk%254%23%27' + '&source=trino-sqlalchemy', + list(), + dict( + host="localhost", + port=8080, + catalog="system", + user="user", + auth=JWTAuthentication("afdlsdfk%4#'"), + http_scheme="https", + source="trino-sqlalchemy" + ), + ), + ( + make_url(trino_url( + user="user", + host="localhost", + session_properties={"query_max_run_time": "1d"}, + http_headers={"trino": 1}, + extra_credential=[("a", "b"), ("c", "d")], + client_tags=["1", "sql"], + experimental_python_types=True, + )), + 'trino://user@localhost:8080' + '?client_tags=%5B%221%22%2C+%22sql%22%5D' + '&experimental_python_types=true' + '&extra_credential=%5B%5B%22a%22%2C+%22b%22%5D%2C+%5B%22c%22%2C+%22d%22%5D%5D' + '&http_headers=%7B%22trino%22%3A+1%7D' + '&session_properties=%7B%22query_max_run_time%22%3A+%221d%22%7D' + '&source=trino-sqlalchemy', list(), dict( host="localhost", @@ -59,23 +126,87 @@ def setup(self): session_properties={"query_max_run_time": "1d"}, http_headers={"trino": 1}, extra_credential=[("a", "b"), ("c", "d")], - client_tags=[1, "sql"], + client_tags=["1", "sql"], experimental_python_types=True, ), ), + # url encoding ( - make_url('trino://user@localhost:8080?roles={"hive":"finance","system":"analyst"}'), + make_url(trino_url( + user="user@test.org/my_role", + password="pass /*&", + host="localhost", + session_properties={"query_max_run_time": "1d"}, + http_headers={"trino": 1}, + extra_credential=[ + ("user1@test.org/my_role", "user2@test.org/my_role"), + ("user3@test.org/my_role", "user36@test.org/my_role")], + experimental_python_types=True, + client_tags=["1 @& /\"", "sql"], + verify=False, + )), + 'trino://user%40test.org%2Fmy_role:***@localhost:8080' + '?client_tags=%5B%221+%40%26+%2F%5C%22%22%2C+%22sql%22%5D' + '&experimental_python_types=true' + '&extra_credential=%5B%5B%22user1%40test.org%2Fmy_role%22%2C+' + '%22user2%40test.org%2Fmy_role%22%5D%2C+' + '%5B%22user3%40test.org%2Fmy_role%22%2C+' + '%22user36%40test.org%2Fmy_role%22%5D%5D' + '&http_headers=%7B%22trino%22%3A+1%7D' + '&session_properties=%7B%22query_max_run_time%22%3A+%221d%22%7D' + '&source=trino-sqlalchemy' + '&verify=false', list(), - dict(host="localhost", - port=8080, - catalog="system", - user="user", - roles={"hive": "finance", "system": "analyst"}, - source="trino-sqlalchemy"), + dict( + host="localhost", + port=8080, + catalog="system", + user="user@test.org/my_role", + auth=BasicAuthentication("user@test.org/my_role", "pass /*&"), + http_scheme="https", + source="trino-sqlalchemy", + session_properties={"query_max_run_time": "1d"}, + http_headers={"trino": 1}, + extra_credential=[ + ("user1@test.org/my_role", "user2@test.org/my_role"), + ("user3@test.org/my_role", "user36@test.org/my_role")], + experimental_python_types=True, + client_tags=["1 @& /\"", "sql"], + verify=False, + ), + ), + ( + make_url(trino_url( + user="user", + host="localhost", + roles={ + "hive": "finance", + "system": "analyst", + } + )), + 'trino://user@localhost:8080' + '?roles=%7B%22hive%22%3A+%22finance%22%2C+%22system%22%3A+%22analyst%22%7D&source=trino-sqlalchemy', + list(), + dict( + host="localhost", + port=8080, + catalog="system", + user="user", + roles={"hive": "finance", "system": "analyst"}, + source="trino-sqlalchemy", + ), ), ], ) - def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_kwargs: Dict[str, Any]): + def test_create_connect_args( + self, + url: URL, + generated_url: str, + expected_args: List[Any], + expected_kwargs: Dict[str, Any] + ): + assert repr(url) == generated_url + actual_args, actual_kwargs = self.dialect.create_connect_args(url) assert actual_args == expected_args diff --git a/trino/sqlalchemy/__init__.py b/trino/sqlalchemy/__init__.py index 000d3e08..3c10f0b8 100644 --- a/trino/sqlalchemy/__init__.py +++ b/trino/sqlalchemy/__init__.py @@ -10,5 +10,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from sqlalchemy.dialects import registry +from .util import _url as URL # noqa registry.register("trino", "trino.sqlalchemy.dialect", "TrinoDialect") diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index c5bccf37..0bd7938b 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -10,9 +10,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -from ast import literal_eval from textwrap import dedent from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple +from urllib.parse import unquote_plus from sqlalchemy import exc, sql from sqlalchemy.engine.base import Connection @@ -80,49 +80,54 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any db_parts = (url.database or "system").split("/") if len(db_parts) == 1: - kwargs["catalog"] = db_parts[0] + kwargs["catalog"] = unquote_plus(db_parts[0]) elif len(db_parts) == 2: - kwargs["catalog"] = db_parts[0] - kwargs["schema"] = db_parts[1] + kwargs["catalog"] = unquote_plus(db_parts[0]) + kwargs["schema"] = unquote_plus(db_parts[1]) else: raise ValueError(f"Unexpected database format {url.database}") if url.username: - kwargs["user"] = url.username + kwargs["user"] = unquote_plus(url.username) if url.password: if not url.username: raise ValueError("Username is required when specify password in connection URL") kwargs["http_scheme"] = "https" - kwargs["auth"] = BasicAuthentication(url.username, url.password) + kwargs["auth"] = BasicAuthentication(unquote_plus(url.username), unquote_plus(url.password)) if "access_token" in url.query: kwargs["http_scheme"] = "https" - kwargs["auth"] = JWTAuthentication(url.query["access_token"]) + kwargs["auth"] = JWTAuthentication(unquote_plus(url.query["access_token"])) if "cert" and "key" in url.query: kwargs["http_scheme"] = "https" - kwargs["auth"] = CertificateAuthentication(url.query['cert'], url.query['key']) + kwargs["auth"] = CertificateAuthentication(unquote_plus(url.query['cert']), unquote_plus(url.query['key'])) if "source" in url.query: - kwargs["source"] = url.query["source"] + kwargs["source"] = unquote_plus(url.query["source"]) else: kwargs["source"] = "trino-sqlalchemy" if "session_properties" in url.query: - kwargs["session_properties"] = json.loads(url.query["session_properties"]) + kwargs["session_properties"] = json.loads(unquote_plus(url.query["session_properties"])) if "http_headers" in url.query: - kwargs["http_headers"] = json.loads(url.query["http_headers"]) + kwargs["http_headers"] = json.loads(unquote_plus(url.query["http_headers"])) if "extra_credential" in url.query: - kwargs["extra_credential"] = literal_eval(url.query["extra_credential"]) + kwargs["extra_credential"] = [ + tuple(extra_credential) for extra_credential in json.loads(unquote_plus(url.query["extra_credential"])) + ] if "client_tags" in url.query: - kwargs["client_tags"] = json.loads(url.query["client_tags"]) + kwargs["client_tags"] = json.loads(unquote_plus(url.query["client_tags"])) if "experimental_python_types" in url.query: - kwargs["experimental_python_types"] = json.loads(url.query["experimental_python_types"]) + kwargs["experimental_python_types"] = json.loads(unquote_plus(url.query["experimental_python_types"])) + + if "verify" in url.query: + kwargs["verify"] = json.loads(unquote_plus(url.query["verify"])) if "roles" in url.query: kwargs["roles"] = json.loads(url.query["roles"]) diff --git a/trino/sqlalchemy/util.py b/trino/sqlalchemy/util.py new file mode 100644 index 00000000..67bd711e --- /dev/null +++ b/trino/sqlalchemy/util.py @@ -0,0 +1,99 @@ +import json +from urllib.parse import quote_plus + +from typing import Optional, Dict, List, Union, Tuple +from sqlalchemy import exc +from sqlalchemy.engine.url import _rfc_1738_quote # noqa + + +def _url( + host: str, + port: Optional[int] = 8080, + user: Optional[str] = None, + password: Optional[str] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, + source: Optional[str] = "trino-sqlalchemy", + session_properties: Dict[str, str] = None, + http_headers: Dict[str, Union[str, int]] = None, + extra_credential: Optional[List[Tuple[str, str]]] = None, + client_tags: Optional[List[str]] = None, + experimental_python_types: Optional[bool] = None, + access_token: Optional[str] = None, + cert: Optional[str] = None, + key: Optional[str] = None, + verify: Optional[bool] = None, + roles: Optional[Dict[str, str]] = None +) -> str: + """ + Composes a SQLAlchemy connection string from the given database connection + parameters. + Parameters containing special characters (e.g., '@', '%') need to be encoded to be parsed correctly. + """ + + trino_url = "trino://" + + if user is not None: + trino_url += _rfc_1738_quote(user) + + if password is not None: + if user is None: + raise exc.ArgumentError("user must be specified when specifying a password.") + trino_url += f":{_rfc_1738_quote(password)}" + + if user is not None: + trino_url += "@" + + if not host: + raise exc.ArgumentError("host must be specified.") + + trino_url += host + + if not port: + raise exc.ArgumentError("port must be specified.") + + trino_url += f":{port}" + + if catalog is not None: + trino_url += f"/{quote_plus(catalog)}" + + if schema is not None: + if catalog is None: + raise exc.ArgumentError("catalog must be specified when specifying a default schema.") + trino_url += f"/{quote_plus(schema)}" + + assert source + trino_url += f"?source={quote_plus(source)}" + + if session_properties is not None: + trino_url += f"&session_properties={quote_plus(json.dumps(session_properties))}" + + if http_headers is not None: + trino_url += f"&http_headers={quote_plus(json.dumps(http_headers))}" + + if extra_credential is not None: + # repr is used here as json.dumps converts tuples into arrays + trino_url += f"&extra_credential={quote_plus(json.dumps(extra_credential))}" + + if client_tags is not None: + trino_url += f"&client_tags={quote_plus(json.dumps(client_tags))}" + + if experimental_python_types is not None: + trino_url += f"&experimental_python_types={json.dumps(experimental_python_types)}" + + if access_token is not None: + trino_url += f"&access_token={quote_plus(access_token)}" + + if cert is not None: + trino_url += f"&cert={quote_plus(cert)}" + + if key is not None: + trino_url += f"&key={quote_plus(key)}" + + if verify is not None: + trino_url += f"&verify={json.dumps(verify)}" + + if roles is not None: + trino_url += f"&roles={quote_plus(json.dumps(roles))}" + + return trino_url