diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index 9753bab5..b19b11bb 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -11,7 +11,12 @@ # limitations under the License. from trino import constants -from trino.client import get_header_values, get_session_property_values +from trino.client import ( + get_header_values, + get_prepared_statement_values, + get_roles_values, + get_session_property_values, +) def test_get_header_values(): @@ -24,3 +29,21 @@ def test_get_session_property_values(): headers = {constants.HEADER_SET_SESSION: "a=1, b=2, c=more%3Dv1%2Cv2"} values = get_session_property_values(headers, constants.HEADER_SET_SESSION) assert values == [("a", "1"), ("b", "2"), ("c", "more=v1,v2")] + + +def test_get_session_property_values_ignores_empty_values(): + headers = {constants.HEADER_SET_SESSION: ""} + values = get_session_property_values(headers, constants.HEADER_SET_SESSION) + assert len(values) == 0 + + +def test_get_prepared_statement_values_ignores_empty_values(): + headers = {constants.HEADER_SET_SESSION: ""} + values = get_prepared_statement_values(headers, constants.HEADER_SET_SESSION) + assert len(values) == 0 + + +def test_get_roles_values_ignores_empty_values(): + headers = {constants.HEADER_SET_SESSION: ""} + values = get_roles_values(headers, constants.HEADER_SET_SESSION) + assert len(values) == 0 diff --git a/trino/client.py b/trino/client.py index 0073b3e3..18e81c9e 100644 --- a/trino/client.py +++ b/trino/client.py @@ -225,7 +225,7 @@ def get_session_property_values(headers, header): kvs = get_header_values(headers, header) return [ (k.strip(), urllib.parse.unquote_plus(v.strip())) - for k, v in (kv.split("=", 1) for kv in kvs) + for k, v in (kv.split("=", 1) for kv in kvs if kv) ] @@ -233,7 +233,7 @@ def get_prepared_statement_values(headers, header): kvs = get_header_values(headers, header) return [ (k.strip(), urllib.parse.unquote_plus(v.strip())) - for k, v in (kv.split("=", 1) for kv in kvs) + for k, v in (kv.split("=", 1) for kv in kvs if kv) ] @@ -241,7 +241,7 @@ def get_roles_values(headers, header): kvs = get_header_values(headers, header) return [ (k.strip(), urllib.parse.unquote_plus(v.strip())) - for k, v in (kv.split("=", 1) for kv in kvs) + for k, v in (kv.split("=", 1) for kv in kvs if kv) ]