diff --git a/python/mkdocs/docs/index.md b/python/mkdocs/docs/index.md index 35ffee187564..b7597d0f0002 100644 --- a/python/mkdocs/docs/index.md +++ b/python/mkdocs/docs/index.md @@ -91,6 +91,14 @@ catalog: rest: uri: http://rest-catalog/ws/ credential: t-1234:secret + + mtls-secured-catalog: + uri: https://rest-catalog/ws/ + ssl: + client: + cert: /absolute/path/to/client.crt + key: /absolute/path/to/client.key + cabundle: /absolute/path/to/cabundle.pem ``` Lastly, you can also set it using environment variables: diff --git a/python/pyiceberg/catalog/rest.py b/python/pyiceberg/catalog/rest.py index 90db6691f946..3caf091660f2 100644 --- a/python/pyiceberg/catalog/rest.py +++ b/python/pyiceberg/catalog/rest.py @@ -25,9 +25,8 @@ Union, ) -import requests from pydantic import Field, ValidationError -from requests import HTTPError +from requests import HTTPError, Session from pyiceberg import __version__ from pyiceberg.catalog import ( @@ -91,6 +90,11 @@ class Endpoints: TOKEN = "token" TOKEN_EXCHANGE = "urn:ietf:params:oauth:grant-type:token-exchange" SEMICOLON = ":" +KEY = "key" +CERT = "cert" +CLIENT = "client" +CA_BUNDLE = "cabundle" +SSL = "ssl" NAMESPACE_SEPARATOR = b"\x1F".decode("UTF-8") @@ -166,8 +170,9 @@ class OAuthErrorResponse(IcebergBaseModel): class RestCatalog(Catalog): - token: Optional[str] uri: str + session: Session + properties: dict def __init__( self, @@ -184,11 +189,35 @@ def __init__( """ self.properties = properties self.uri = properties[URI] - - if credential := properties.get(CREDENTIAL): - properties[TOKEN] = self._fetch_access_token(credential) + self._create_session() super().__init__(name, **self._fetch_config(properties)) + def _create_session(self) -> None: + """Creates a request session with provided catalog configuration""" + + self.session = Session() + # Sets the client side and server side SSL cert verification, if provided as properties. 
+ if ssl_config := self.properties.get(SSL): + if ssl_ca_bundle := ssl_config.get(CA_BUNDLE): + self.session.verify = ssl_ca_bundle + if ssl_client := ssl_config.get(CLIENT): + if all(k in ssl_client for k in (CERT, KEY)): + self.session.cert = (ssl_client[CERT], ssl_client[KEY]) + elif ssl_client_cert := ssl_client.get(CERT): + self.session.cert = ssl_client_cert + + # Set Auth token for subsequent calls in the session + if token := self.properties.get(TOKEN): + self.session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}" + elif credential := self.properties.get(CREDENTIAL): + token = self._fetch_access_token(credential) + self.session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}" + + # Set HTTP headers + self.session.headers["Content-type"] = "application/json" + self.session.headers["X-Client-Version"] = ICEBERG_REST_SPEC_VERSION + self.session.headers["User-Agent"] = f"PyIceberg/{__version__}" + def _check_valid_namespace_identifier(self, identifier: Union[str, Identifier]) -> Identifier: """The identifier should have at least one element""" identifier_tuple = Catalog.identifier_to_tuple(identifier) @@ -196,17 +225,6 @@ def _check_valid_namespace_identifier(self, identifier: Union[str, Identifier]) raise NoSuchNamespaceError(f"Empty namespace identifier: {identifier}") return identifier_tuple - @property - def headers(self) -> Properties: - headers = { - "Content-type": "application/json", - "X-Client-Version": ICEBERG_REST_SPEC_VERSION, - "User-Agent": f"PyIceberg/{__version__}", - } - if token := self.properties.get("token"): - headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}" - return headers - def url(self, endpoint: str, prefixed: bool = True, **kwargs) -> str: """Constructs the endpoint @@ -235,7 +253,7 @@ def _fetch_access_token(self, credential: str) -> str: data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret, SCOPE: CATALOG_SCOPE} url = self.url(Endpoints.get_token, prefixed=False) 
# Uses application/x-www-form-urlencoded by default - response = requests.post(url=url, data=data) + response = self.session.post(url=url, data=data) try: response.raise_for_status() except HTTPError as exc: @@ -244,7 +262,7 @@ def _fetch_access_token(self, credential: str) -> str: return TokenResponse(**response.json()).access_token def _fetch_config(self, properties: Properties) -> Properties: - response = requests.get(self.url(Endpoints.get_config, prefixed=False), headers=self.headers) + response = self.session.get(self.url(Endpoints.get_config, prefixed=False)) try: response.raise_for_status() except HTTPError as exc: @@ -334,10 +352,9 @@ def create_table( properties=properties, ) serialized_json = request.json() - response = requests.post( + response = self.session.post( self.url(Endpoints.create_table, namespace=namespace_and_table["namespace"]), data=serialized_json, - headers=self.headers, ) try: response.raise_for_status() @@ -355,10 +372,7 @@ def create_table( def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]: namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple) - response = requests.get( - self.url(Endpoints.list_tables, namespace=namespace_concat), - headers=self.headers, - ) + response = self.session.get(self.url(Endpoints.list_tables, namespace=namespace_concat)) try: response.raise_for_status() except HTTPError as exc: @@ -371,9 +385,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table: if len(identifier_tuple) <= 1: raise NoSuchTableError(f"Missing namespace or invalid identifier: {identifier}") - response = requests.get( - self.url(Endpoints.load_table, prefixed=True, **self._split_identifier_for_path(identifier)), headers=self.headers - ) + response = self.session.get(self.url(Endpoints.load_table, prefixed=True, **self._split_identifier_for_path(identifier))) try: response.raise_for_status() except HTTPError as exc: @@ -387,9 
+399,8 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table: ) def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool = False) -> None: - response = requests.delete( + response = self.session.delete( self.url(Endpoints.drop_table, prefixed=True, purge=purge_requested, **self._split_identifier_for_path(identifier)), - headers=self.headers, ) try: response.raise_for_status() @@ -404,7 +415,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U "source": self._split_identifier_for_json(from_identifier), "destination": self._split_identifier_for_json(to_identifier), } - response = requests.post(self.url(Endpoints.rename_table), json=payload, headers=self.headers) + response = self.session.post(self.url(Endpoints.rename_table), json=payload) try: response.raise_for_status() except HTTPError as exc: @@ -413,7 +424,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None: namespace_tuple = self._check_valid_namespace_identifier(namespace) payload = {"namespace": namespace_tuple, "properties": properties} - response = requests.post(self.url(Endpoints.create_namespace), json=payload, headers=self.headers) + response = self.session.post(self.url(Endpoints.create_namespace), json=payload) try: response.raise_for_status() except HTTPError as exc: @@ -422,7 +433,7 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper def drop_namespace(self, namespace: Union[str, Identifier]) -> None: namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace = NAMESPACE_SEPARATOR.join(namespace_tuple) - response = requests.delete(self.url(Endpoints.drop_namespace, namespace=namespace), headers=self.headers) + response = self.session.delete(self.url(Endpoints.drop_namespace, namespace=namespace)) try: response.raise_for_status() except 
HTTPError as exc: @@ -430,13 +441,12 @@ def drop_namespace(self, namespace: Union[str, Identifier]) -> None: def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]: namespace_tuple = self.identifier_to_tuple(namespace) - response = requests.get( + response = self.session.get( self.url( f"{Endpoints.list_namespaces}?parent={NAMESPACE_SEPARATOR.join(namespace_tuple)}" if namespace_tuple else Endpoints.list_namespaces ), - headers=self.headers, ) try: response.raise_for_status() @@ -449,7 +459,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties: namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace = NAMESPACE_SEPARATOR.join(namespace_tuple) - response = requests.get(self.url(Endpoints.load_namespace_metadata, namespace=namespace), headers=self.headers) + response = self.session.get(self.url(Endpoints.load_namespace_metadata, namespace=namespace)) try: response.raise_for_status() except HTTPError as exc: @@ -463,7 +473,7 @@ def update_namespace_properties( namespace_tuple = self._check_valid_namespace_identifier(namespace) namespace = NAMESPACE_SEPARATOR.join(namespace_tuple) payload = {"removals": list(removals or []), "updates": updates} - response = requests.post(self.url(Endpoints.update_properties, namespace=namespace), json=payload, headers=self.headers) + response = self.session.post(self.url(Endpoints.update_properties, namespace=namespace), json=payload) try: response.raise_for_status() except HTTPError as exc: diff --git a/python/tests/catalog/test_rest.py b/python/tests/catalog/test_rest.py index e91d8e3674f4..8779f3b9473a 100644 --- a/python/tests/catalog/test_rest.py +++ b/python/tests/catalog/test_rest.py @@ -20,6 +20,7 @@ import pytest from requests_mock import Mocker +import pyiceberg from pyiceberg.catalog import PropertiesUpdateSummary, Table from pyiceberg.catalog.rest import 
RestCatalog from pyiceberg.exceptions import ( @@ -46,6 +47,15 @@ TEST_URI = "https://iceberg-test-catalog/" TEST_CREDENTIALS = "client:secret" TEST_TOKEN = "some_jwt_token" +TEST_HEADERS = { + "Content-type": "application/json", + "X-Client-Version": "0.14.1", + "User-Agent": f"PyIceberg/{pyiceberg.__version__}", + "Authorization": f"Bearer {TEST_TOKEN}", +} +OAUTH_TEST_HEADERS = { + "Content-type": "application/x-www-form-urlencoded", +} @pytest.fixture @@ -77,8 +87,11 @@ def test_token_200(rest_mock: Mocker): "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", }, status_code=200, + request_headers=OAUTH_TEST_HEADERS, + ) + assert ( + RestCatalog("rest", uri=TEST_URI, credential=TEST_CREDENTIALS).session.headers["Authorization"] == f"Bearer {TEST_TOKEN}" ) - assert RestCatalog("rest", uri=TEST_URI, credential=TEST_CREDENTIALS).properties["token"] == TEST_TOKEN def test_token_400(rest_mock: Mocker): @@ -86,6 +99,7 @@ def test_token_400(rest_mock: Mocker): f"{TEST_URI}v1/oauth/tokens", json={"error": "invalid_client", "error_description": "Credentials for key invalid_key do not match"}, status_code=400, + request_headers=OAUTH_TEST_HEADERS, ) with pytest.raises(OAuthError) as e: @@ -99,6 +113,7 @@ def test_token_401(rest_mock: Mocker): f"{TEST_URI}v1/oauth/tokens", json={"error": "invalid_client", "error_description": "Unknown or invalid client"}, status_code=401, + request_headers=OAUTH_TEST_HEADERS, ) with pytest.raises(OAuthError) as e: @@ -112,6 +127,7 @@ def test_list_tables_200(rest_mock: Mocker): f"{TEST_URI}v1/namespaces/{namespace}/tables", json={"identifiers": [{"namespace": ["examples"], "name": "fooshare"}]}, status_code=200, + request_headers=TEST_HEADERS, ) assert RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).list_tables(namespace) == [("examples", "fooshare")] @@ -129,6 +145,7 @@ def test_list_tables_404(rest_mock: Mocker): } }, status_code=404, + request_headers=TEST_HEADERS, ) with pytest.raises(NoSuchNamespaceError) as e: 
RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).list_tables(namespace) @@ -140,6 +157,7 @@ def test_list_namespaces_200(rest_mock: Mocker): f"{TEST_URI}v1/namespaces", json={"namespaces": [["default"], ["examples"], ["fokko"], ["system"]]}, status_code=200, + request_headers=TEST_HEADERS, ) assert RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).list_namespaces() == [ ("default",), @@ -154,6 +172,7 @@ def test_list_namespace_with_parent_200(rest_mock: Mocker): f"{TEST_URI}v1/namespaces?parent=accounting", json={"namespaces": [["tax"]]}, status_code=200, + request_headers=TEST_HEADERS, ) assert RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).list_namespaces(("accounting",)) == [ ("accounting", "tax"), @@ -166,6 +185,7 @@ def test_create_namespace_200(rest_mock: Mocker): f"{TEST_URI}v1/namespaces", json={"namespace": [namespace], "properties": {}}, status_code=200, + request_headers=TEST_HEADERS, ) RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).create_namespace(namespace) @@ -182,6 +202,7 @@ def test_create_namespace_409(rest_mock: Mocker): } }, status_code=409, + request_headers=TEST_HEADERS, ) with pytest.raises(NamespaceAlreadyExistsError) as e: RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).create_namespace(namespace) @@ -200,6 +221,7 @@ def test_drop_namespace_404(rest_mock: Mocker): } }, status_code=404, + request_headers=TEST_HEADERS, ) with pytest.raises(NoSuchNamespaceError) as e: RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).drop_namespace(namespace) @@ -212,6 +234,7 @@ def test_load_namespace_properties_200(rest_mock: Mocker): f"{TEST_URI}v1/namespaces/{namespace}", json={"namespace": ["fokko"], "properties": {"prop": "yes"}}, status_code=204, + request_headers=TEST_HEADERS, ) assert RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).load_namespace_properties(namespace) == {"prop": "yes"} @@ -228,6 +251,7 @@ def test_load_namespace_properties_404(rest_mock: Mocker): } }, status_code=404, + request_headers=TEST_HEADERS, ) with 
pytest.raises(NoSuchNamespaceError) as e: RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).load_namespace_properties(namespace) @@ -239,6 +263,7 @@ def test_update_namespace_properties_200(rest_mock: Mocker): f"{TEST_URI}v1/namespaces/fokko/properties", json={"removed": [], "updated": ["prop"], "missing": ["abc"]}, status_code=200, + request_headers=TEST_HEADERS, ) response = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).update_namespace_properties( ("fokko",), {"abc"}, {"prop": "yes"} @@ -258,6 +283,7 @@ def test_update_namespace_properties_404(rest_mock: Mocker): } }, status_code=404, + request_headers=TEST_HEADERS, ) with pytest.raises(NoSuchNamespaceError) as e: RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).update_namespace_properties(("fokko",), {"abc"}, {"prop": "yes"}) @@ -336,6 +362,7 @@ def test_load_table_200(rest_mock: Mocker): "config": {"client.factory": "io.tabular.iceberg.catalog.TabularAwsClientFactory", "region": "us-west-2"}, }, status_code=200, + request_headers=TEST_HEADERS, ) actual = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).load_table(("fokko", "table")) expected = Table( @@ -428,6 +455,7 @@ def test_load_table_404(rest_mock: Mocker): } }, status_code=404, + request_headers=TEST_HEADERS, ) with pytest.raises(NoSuchTableError) as e: @@ -446,6 +474,7 @@ def test_drop_table_404(rest_mock: Mocker): } }, status_code=404, + request_headers=TEST_HEADERS, ) with pytest.raises(NoSuchTableError) as e: @@ -510,6 +539,7 @@ def test_create_table_200(rest_mock: Mocker, table_schema_simple: Schema): }, }, status_code=200, + request_headers=TEST_HEADERS, ) table = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).create_table( identifier=("fokko", "fokko2"), @@ -578,6 +608,7 @@ def test_create_table_409(rest_mock, table_schema_simple: Schema): } }, status_code=409, + request_headers=TEST_HEADERS, ) with pytest.raises(TableAlreadyExistsError) as e: @@ -600,6 +631,7 @@ def test_delete_namespace_204(rest_mock: Mocker): 
f"{TEST_URI}v1/namespaces/{namespace}", json={}, status_code=204, + request_headers=TEST_HEADERS, ) RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).drop_namespace(namespace) @@ -609,6 +641,7 @@ def test_delete_table_204(rest_mock: Mocker): f"{TEST_URI}v1/namespaces/example/tables/fokko", json={}, status_code=204, + request_headers=TEST_HEADERS, ) RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).drop_table(("example", "fokko")) @@ -624,6 +657,7 @@ def test_delete_table_404(rest_mock: Mocker): } }, status_code=404, + request_headers=TEST_HEADERS, ) with pytest.raises(NoSuchTableError) as e: RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).drop_table(("example", "fokko")) @@ -688,3 +722,36 @@ def test_update_namespace_properties_invalid_namespace(rest_mock: Mocker): # Missing namespace RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).update_namespace_properties(()) assert "Empty namespace identifier" in str(e.value) + + +def test_request_session_with_ssl_ca_bundle(): + # Given + catalog_properties = { + "uri": TEST_URI, + "token": TEST_TOKEN, + "ssl": { + "cabundle": "path_to_ca_bundle", + }, + } + with pytest.raises(OSError) as e: + # Invalid CA bundle path + RestCatalog("rest", **catalog_properties) + assert "Could not find a suitable TLS CA certificate bundle, invalid path: path_to_ca_bundle" in str(e.value) + + +def test_request_session_with_ssl_client_cert(): + # Given + catalog_properties = { + "uri": TEST_URI, + "token": TEST_TOKEN, + "ssl": { + "client": { + "cert": "path_to_client_cert", + "key": "path_to_client_key", + } + }, + } + with pytest.raises(OSError) as e: + # Invalid client certificate path + RestCatalog("rest", **catalog_properties) + assert "Could not find the TLS certificate file, invalid path: path_to_client_cert" in str(e.value)