Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion tests/unit/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
# limitations under the License.

from trino import constants
from trino.client import get_header_values, get_session_property_values
from trino.client import (
get_header_values,
get_prepared_statement_values,
get_roles_values,
get_session_property_values,
)


def test_get_header_values():
Expand All @@ -24,3 +29,21 @@ def test_get_session_property_values():
headers = {constants.HEADER_SET_SESSION: "a=1, b=2, c=more%3Dv1%2Cv2"}
values = get_session_property_values(headers, constants.HEADER_SET_SESSION)
assert values == [("a", "1"), ("b", "2"), ("c", "more=v1,v2")]


def test_get_session_property_values_ignores_empty_values():
headers = {constants.HEADER_SET_SESSION: ""}
values = get_session_property_values(headers, constants.HEADER_SET_SESSION)
assert len(values) == 0


def test_get_prepared_statement_values_ignores_empty_values():
headers = {constants.HEADER_SET_SESSION: ""}
values = get_prepared_statement_values(headers, constants.HEADER_SET_SESSION)
assert len(values) == 0


def test_get_roles_values_ignores_empty_values():
headers = {constants.HEADER_SET_SESSION: ""}
values = get_roles_values(headers, constants.HEADER_SET_SESSION)
assert len(values) == 0
6 changes: 3 additions & 3 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,23 +225,23 @@ def get_session_property_values(headers, header):
kvs = get_header_values(headers, header)
return [
(k.strip(), urllib.parse.unquote_plus(v.strip()))
for k, v in (kv.split("=", 1) for kv in kvs)
for k, v in (kv.split("=", 1) for kv in kvs if kv)
]


def get_prepared_statement_values(headers, header):
kvs = get_header_values(headers, header)
return [
(k.strip(), urllib.parse.unquote_plus(v.strip()))
for k, v in (kv.split("=", 1) for kv in kvs)
for k, v in (kv.split("=", 1) for kv in kvs if kv)
]


def get_roles_values(headers, header):
kvs = get_header_values(headers, header)
return [
(k.strip(), urllib.parse.unquote_plus(v.strip()))
for k, v in (kv.split("=", 1) for kv in kvs)
for k, v in (kv.split("=", 1) for kv in kvs if kv)
]


Expand Down