Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/mkdocs/docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ catalog:
  rest:
    uri: http://rest-catalog/ws/
    credential: t-1234:secret

  mtls-secured-catalog:
    uri: https://rest-catalog/ws/
    ssl:
      client:
        cert: /absolute/path/to/client.crt
        key: /absolute/path/to/client.key
      cabundle: /absolute/path/to/cabundle.pem
```

Lastly, you can also set it using environment variables:
Expand Down
84 changes: 47 additions & 37 deletions python/pyiceberg/catalog/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@
Union,
)

import requests
from pydantic import Field, ValidationError
from requests import HTTPError
from requests import HTTPError, Session

from pyiceberg import __version__
from pyiceberg.catalog import (
Expand Down Expand Up @@ -91,6 +90,11 @@ class Endpoints:
# Catalog property key under which a bearer token is looked up.
TOKEN = "token"
TOKEN_EXCHANGE = "urn:ietf:params:oauth:grant-type:token-exchange"
# NOTE(review): despite its name this constant holds a colon (":"), not a
# semicolon; its usage is outside the visible part of the file.
SEMICOLON = ":"
# Keys of the nested `ssl` catalog property read by RestCatalog._create_session:
#   ssl -> cabundle           value assigned to the session's `verify` attribute
#   ssl -> client -> cert/key value(s) assigned to the session's `cert` attribute
KEY = "key"
CERT = "cert"
CLIENT = "client"
CA_BUNDLE = "cabundle"
SSL = "ssl"

# The single character 0x1F; namespace tuples are joined with it before being
# passed to `url()` as the `namespace` argument.
NAMESPACE_SEPARATOR = b"\x1F".decode("UTF-8")

Expand Down Expand Up @@ -166,8 +170,9 @@ class OAuthErrorResponse(IcebergBaseModel):


class RestCatalog(Catalog):
token: Optional[str]
uri: str
session: Session
properties: dict

def __init__(
self,
Expand All @@ -184,29 +189,42 @@ def __init__(
"""
self.properties = properties
self.uri = properties[URI]

if credential := properties.get(CREDENTIAL):
properties[TOKEN] = self._fetch_access_token(credential)
self._create_session()
super().__init__(name, **self._fetch_config(properties))

def _create_session(self) -> None:
    """Build the ``Session`` stored on ``self.session`` from the catalog properties.

    Three things are configured, in this order:

    1. TLS: when the ``ssl`` property is present, its ``cabundle`` entry is
       assigned to ``session.verify`` and its ``client`` entry to
       ``session.cert`` (a ``(cert, key)`` tuple when both keys are present,
       otherwise just the ``cert`` value when it is set).
    2. Authorization: a ``token`` property is used as-is; otherwise, when a
       ``credential`` property is present, a token is fetched with
       ``_fetch_access_token``.
    3. The default HTTP headers sent with every request.
    """
    http = Session()
    # Must be published on the instance first: _fetch_access_token() below
    # issues its request through self.session.
    self.session = http

    # Client side and server side SSL cert verification, if provided as properties.
    ssl_settings = self.properties.get(SSL)
    if ssl_settings:
        ca_bundle = ssl_settings.get(CA_BUNDLE)
        if ca_bundle:
            http.verify = ca_bundle

        client_settings = ssl_settings.get(CLIENT)
        if client_settings:
            if CERT in client_settings and KEY in client_settings:
                http.cert = (client_settings[CERT], client_settings[KEY])
            else:
                cert_only = client_settings.get(CERT)
                if cert_only:
                    http.cert = cert_only

    # Auth token for subsequent calls in the session: an explicit token wins,
    # a credential is only exchanged when no token was supplied.
    bearer = self.properties.get(TOKEN)
    authorize = bool(bearer)
    if not authorize:
        client_credential = self.properties.get(CREDENTIAL)
        if client_credential:
            bearer = self._fetch_access_token(client_credential)
            authorize = True
    if authorize:
        http.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {bearer}"

    # Default headers are installed only after the token exchange above, so
    # that request is sent without the JSON Content-type default.
    http.headers.update(
        {
            "Content-type": "application/json",
            "X-Client-Version": ICEBERG_REST_SPEC_VERSION,
            "User-Agent": f"PyIceberg/{__version__}",
        }
    )

def _check_valid_namespace_identifier(self, identifier: Union[str, Identifier]) -> Identifier:
    """Return ``identifier`` as a tuple, rejecting identifiers with no elements.

    Args:
        identifier: Namespace identifier, converted with ``Catalog.identifier_to_tuple``.

    Returns:
        The tuple form of the identifier.

    Raises:
        NoSuchNamespaceError: If the converted identifier has no elements.
    """
    parts = Catalog.identifier_to_tuple(identifier)
    if len(parts) >= 1:
        return parts
    raise NoSuchNamespaceError(f"Empty namespace identifier: {identifier}")

@property
def headers(self) -> Properties:
    """HTTP headers for a catalog request.

    Always contains the content type, client version and user agent; the
    authorization header is added only when the ``token`` property is set
    to a truthy value.
    """
    bearer = self.properties.get("token")
    authorization = {AUTHORIZATION_HEADER: f"{BEARER_PREFIX} {bearer}"} if bearer else {}
    return {
        "Content-type": "application/json",
        "X-Client-Version": ICEBERG_REST_SPEC_VERSION,
        "User-Agent": f"PyIceberg/{__version__}",
        **authorization,
    }

def url(self, endpoint: str, prefixed: bool = True, **kwargs) -> str:
"""Constructs the endpoint

Expand Down Expand Up @@ -235,7 +253,7 @@ def _fetch_access_token(self, credential: str) -> str:
data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret, SCOPE: CATALOG_SCOPE}
url = self.url(Endpoints.get_token, prefixed=False)
# Uses application/x-www-form-urlencoded by default
response = requests.post(url=url, data=data)
response = self.session.post(url=url, data=data)
try:
response.raise_for_status()
except HTTPError as exc:
Expand All @@ -244,7 +262,7 @@ def _fetch_access_token(self, credential: str) -> str:
return TokenResponse(**response.json()).access_token

def _fetch_config(self, properties: Properties) -> Properties:
response = requests.get(self.url(Endpoints.get_config, prefixed=False), headers=self.headers)
response = self.session.get(self.url(Endpoints.get_config, prefixed=False))
try:
response.raise_for_status()
except HTTPError as exc:
Expand Down Expand Up @@ -334,10 +352,9 @@ def create_table(
properties=properties,
)
serialized_json = request.json()
response = requests.post(
response = self.session.post(
self.url(Endpoints.create_table, namespace=namespace_and_table["namespace"]),
data=serialized_json,
headers=self.headers,
)
try:
response.raise_for_status()
Expand All @@ -355,10 +372,7 @@ def create_table(
def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple)
response = requests.get(
self.url(Endpoints.list_tables, namespace=namespace_concat),
headers=self.headers,
)
response = self.session.get(self.url(Endpoints.list_tables, namespace=namespace_concat))
try:
response.raise_for_status()
except HTTPError as exc:
Expand All @@ -371,9 +385,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table:
if len(identifier_tuple) <= 1:
raise NoSuchTableError(f"Missing namespace or invalid identifier: {identifier}")

response = requests.get(
self.url(Endpoints.load_table, prefixed=True, **self._split_identifier_for_path(identifier)), headers=self.headers
)
response = self.session.get(self.url(Endpoints.load_table, prefixed=True, **self._split_identifier_for_path(identifier)))
try:
response.raise_for_status()
except HTTPError as exc:
Expand All @@ -387,9 +399,8 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table:
)

def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool = False) -> None:
response = requests.delete(
response = self.session.delete(
self.url(Endpoints.drop_table, prefixed=True, purge=purge_requested, **self._split_identifier_for_path(identifier)),
headers=self.headers,
)
try:
response.raise_for_status()
Expand All @@ -404,7 +415,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U
"source": self._split_identifier_for_json(from_identifier),
"destination": self._split_identifier_for_json(to_identifier),
}
response = requests.post(self.url(Endpoints.rename_table), json=payload, headers=self.headers)
response = self.session.post(self.url(Endpoints.rename_table), json=payload)
try:
response.raise_for_status()
except HTTPError as exc:
Expand All @@ -413,7 +424,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U
def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None:
namespace_tuple = self._check_valid_namespace_identifier(namespace)
payload = {"namespace": namespace_tuple, "properties": properties}
response = requests.post(self.url(Endpoints.create_namespace), json=payload, headers=self.headers)
response = self.session.post(self.url(Endpoints.create_namespace), json=payload)
try:
response.raise_for_status()
except HTTPError as exc:
Expand All @@ -422,21 +433,20 @@ def create_namespace(self, namespace: Union[str, Identifier], properties: Proper
def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
    """Drop a namespace by sending a DELETE request to the catalog.

    Args:
        namespace: Namespace identifier; it must contain at least one element.

    Raises:
        NoSuchNamespaceError: If the identifier is empty. An HTTP error
            response is handed to ``_handle_non_200_response`` with a
            ``{404: NoSuchNamespaceError}`` mapping.
    """
    namespace_tuple = self._check_valid_namespace_identifier(namespace)
    namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
    # Exactly one DELETE, issued through the shared session so the TLS settings
    # and headers configured in _create_session apply. The earlier
    # `requests.delete(..., headers=self.headers)` call is gone: keeping both
    # would send the request twice.
    response = self.session.delete(self.url(Endpoints.drop_namespace, namespace=namespace))
    try:
        response.raise_for_status()
    except HTTPError as exc:
        self._handle_non_200_response(exc, {404: NoSuchNamespaceError})

def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]:
namespace_tuple = self.identifier_to_tuple(namespace)
response = requests.get(
response = self.session.get(
self.url(
f"{Endpoints.list_namespaces}?parent={NAMESPACE_SEPARATOR.join(namespace_tuple)}"
if namespace_tuple
else Endpoints.list_namespaces
),
headers=self.headers,
)
try:
response.raise_for_status()
Expand All @@ -449,7 +459,7 @@ def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identi
def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties:
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
response = requests.get(self.url(Endpoints.load_namespace_metadata, namespace=namespace), headers=self.headers)
response = self.session.get(self.url(Endpoints.load_namespace_metadata, namespace=namespace))
try:
response.raise_for_status()
except HTTPError as exc:
Expand All @@ -463,7 +473,7 @@ def update_namespace_properties(
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
payload = {"removals": list(removals or []), "updates": updates}
response = requests.post(self.url(Endpoints.update_properties, namespace=namespace), json=payload, headers=self.headers)
response = self.session.post(self.url(Endpoints.update_properties, namespace=namespace), json=payload)
try:
response.raise_for_status()
except HTTPError as exc:
Expand Down
Loading