Skip to content

Commit

Permalink
fix: can instantiate connectors without callbacks/secrets-keeper (TCT…
Browse files Browse the repository at this point in the history
…C-10081) (#1894)

* fix: can instantiate connectors without callbacks/secrets-keeper

* doc: better docstring

* fix: dont overload pydantic init and correct namings

* fix: encapsulate secret keeper
  • Loading branch information
julien-pinchelimouroux authored Feb 6, 2025
1 parent 6070819 commit 4313dc4
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 20 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## Unreleased

### Fixed

- OAuth2 & GoogleSheets: connectors can now be instantiated without providing a secrets keeper or callback functions.

## [7.7.3] 2025-01-29

### Changed
Expand Down
20 changes: 20 additions & 0 deletions tests/google_sheets/test_google_sheets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import pandas as pd
import pytest
from googleapiclient.http import HttpMock
from pandas import DataFrame
from pandas.testing import assert_frame_equal
Expand All @@ -16,6 +17,7 @@
from toucan_connectors.google_sheets.google_sheets_connector import (
GoogleSheetsConnector,
GoogleSheetsDataSource,
GoogleSheetsInvalidConfiguration,
parse_cell_value,
serial_number_to_date,
)
Expand Down Expand Up @@ -328,3 +330,21 @@ def test_default_format():
value = 44303 # Example serial number representing a date
result = parse_cell_value(value)
assert result == value


def test_can_instantiate_without_retrieve_token_callback():
    """The connector can be built without passing a ``retrieve_token`` callback."""
    connector = GoogleSheetsConnector(name="test_connector", auth_id="test_auth_id")

    # The callback defaults to None while the other fields are still populated.
    assert connector.retrieve_token is None
    assert connector.auth_id.get_secret_value() == "test_auth_id"


def test_raises_when_trying_to_retrieve_token_if_callable_missing():
    """``get_status`` raises when no ``retrieve_token`` callback was configured."""
    connector = GoogleSheetsConnector(name="test_connector", auth_id="test_auth_id")

    with pytest.raises(GoogleSheetsInvalidConfiguration):
        connector.get_status()
30 changes: 30 additions & 0 deletions tests/oauth2_connector/test_oauth2connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
NoOAuth2RefreshToken,
OAuth2Connector,
OAuth2ConnectorConfig,
SecretKeeperMissingError,
)
from toucan_connectors.snowflake_oauth2.snowflake_oauth2_connector import SnowflakeoAuth2Connector
from toucan_connectors.toucan_connector import get_oauth2_configuration
Expand All @@ -33,6 +34,18 @@ def oauth2_connector(secrets_keeper):
)


@pytest.fixture
def oauth2_connector_without_secret_keeper():
    """Build an ``OAuth2Connector`` without passing any ``secrets_keeper``."""
    empty_config = OAuth2ConnectorConfig(client_id="", client_secret="")
    return OAuth2Connector(
        auth_flow_id="test",
        authorization_url=FAKE_AUTHORIZATION_URL,
        scope=SCOPE,
        config=empty_config,
        redirect_uri="",
        token_url=FAKE_TOKEN_URL,
    )


def test_build_authorization_url(mocker, oauth2_connector, secrets_keeper):
"""
It should return the authorization URL
Expand Down Expand Up @@ -234,3 +247,20 @@ def test_get_refresh_token(mocker, oauth2_connector):
mocked_load.return_value = {"refresh_token": "bla"}
token = oauth2_connector.get_refresh_token()
assert token == "bla"


def test_raises_exception_if_secret_keeper_not_set(oauth2_connector_without_secret_keeper: OAuth2Connector):
    """Each operation below must raise ``SecretKeeperMissingError`` without a secrets keeper."""
    connector = oauth2_connector_without_secret_keeper

    # Checked in the same order as before; the first call that does not raise fails the test.
    calls_needing_secrets_keeper = (
        connector.get_access_token,
        lambda: connector.retrieve_tokens(authorization_response=""),
        connector.build_authorization_url,
        connector.get_access_data,
        connector.get_refresh_token,
    )

    for call in calls_needing_secrets_keeper:
        with pytest.raises(SecretKeeperMissingError):
            call()
41 changes: 32 additions & 9 deletions toucan_connectors/google_sheets/google_sheets_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
getLogger(__name__).warning(f"Missing dependencies for {__name__}: {exc}")
CONNECTOR_OK = False

from pydantic import Field, PrivateAttr, create_model
from pydantic import Field, create_model
from pydantic.json_schema import DEFAULT_REF_TEMPLATE, GenerateJsonSchema, JsonSchemaMode

from toucan_connectors.common import ConnectorStatus
from toucan_connectors.common import UI_HIDDEN, ConnectorStatus
from toucan_connectors.toucan_connector import (
PlainJsonSecretStr,
ToucanConnector,
Expand All @@ -29,6 +29,18 @@
)


class GoogleSheetsError(Exception):
    """Root of the Google Sheets connector exception hierarchy."""


class GoogleSheetsRetrieveTokenError(GoogleSheetsError):
    """Raised when retrieving the access token fails."""


class GoogleSheetsInvalidConfiguration(GoogleSheetsError):
    """Raised when the connector is not configured correctly."""


class GoogleSheetsDataSource(ToucanDataSource):
domain: str = Field(
...,
Expand Down Expand Up @@ -88,17 +100,28 @@ class GoogleSheetsConnector(ToucanConnector, data_source_model=GoogleSheetsDataS
_auth_flow = "managed_oauth2"
_managed_oauth_service_id = "google-sheets"
_oauth_trigger = "retrieve_token"
_retrieve_token: Callable[[str, str], str] = PrivateAttr()

retrieve_token: Callable[[str, str], str] | None = Field(None, **UI_HIDDEN)
auth_id: PlainJsonSecretStr = None

def __init__(self, retrieve_token: Callable[[str, str], str], *args, **kwargs):
super().__init__(**kwargs)
self._retrieve_token = retrieve_token # Could be async
def _call_retrieve_token(self) -> str:
    """Fetch the Google Sheets access token through the ``retrieve_token`` callback.

    The callback is invoked with the managed OAuth service id and the
    secret value of ``auth_id``.

    Raises:
        GoogleSheetsInvalidConfiguration: if no ``retrieve_token`` callback is set.
        GoogleSheetsRetrieveTokenError: if reading ``auth_id`` or calling the
            callback raises any exception.
    """
    callback = self.retrieve_token
    if callback is None:
        raise GoogleSheetsInvalidConfiguration(
            "Retrieve token callback function is not configured. Please provide it at instantiation."
        )

    try:
        # Reading auth_id stays inside the try so that a failure here is
        # wrapped exactly like a failure of the callback itself.
        auth_id = self.auth_id.get_secret_value()
        access_token = callback(self._managed_oauth_service_id, auth_id)
    except Exception as exc:
        raise GoogleSheetsRetrieveTokenError(str(exc)) from exc
    else:
        return access_token

def _google_client_build_kwargs(self): # pragma: no cover
# Override it for testing purposes
access_token = self._retrieve_token(self._managed_oauth_service_id, self.auth_id.get_secret_value())
access_token = self._call_retrieve_token()
return {"credentials": Credentials(token=access_token)}

def _google_client_request_kwargs(self): # pragma: no cover
Expand Down Expand Up @@ -138,8 +161,8 @@ def get_status(self) -> ConnectorStatus:
If successful, returns a message with the email of the connected user account.
"""
try:
access_token = self._retrieve_token(self._managed_oauth_service_id, self.auth_id.get_secret_value())
except Exception:
access_token = self._call_retrieve_token()
except GoogleSheetsRetrieveTokenError:
return ConnectorStatus(status=False, error="Credentials are missing")

if not access_token:
Expand Down
31 changes: 20 additions & 11 deletions toucan_connectors/oauth2_connector/oauth2connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ def load(self, key: str, **kwargs) -> Any:
"""


class SecretKeeperMissingError(Exception):
    """Raised when the OAuth2 connector has no secrets keeper set."""


class OAuth2ConnectorConfig(BaseModel):
client_id: str
client_secret: SecretStr
Expand All @@ -55,8 +59,8 @@ def __init__(
scope: str,
config: OAuth2ConnectorConfig,
redirect_uri: str,
secrets_keeper: SecretsKeeper,
token_url: str,
secrets_keeper: SecretsKeeper | None = None,
):
self.auth_flow_id = auth_flow_id
self.authorization_url = authorization_url
Expand All @@ -66,6 +70,11 @@ def __init__(
self.token_url = token_url
self.redirect_uri = redirect_uri

def _secrets_keeper(self) -> SecretsKeeper:
    """Return the configured secrets keeper.

    Raises:
        SecretKeeperMissingError: if ``self.secrets_keeper`` is ``None``.
    """
    keeper = self.secrets_keeper
    if keeper is not None:
        return keeper
    raise SecretKeeperMissingError("Secret keeper is not set on oauth2 connector.")

def build_authorization_url(self, **kwargs) -> str:
"""Build an authorization request that will be sent to the client."""
from authlib.common.security import generate_token
Expand All @@ -79,7 +88,7 @@ def build_authorization_url(self, **kwargs) -> str:
state = {"token": generate_token(), **kwargs}
uri, state = client.create_authorization_url(self.authorization_url, state=JsonWrapper.dumps(state))

self.secrets_keeper.save(self.auth_flow_id, {"state": state})
self._secrets_keeper().save(self.auth_flow_id, {"state": state})
return uri

def retrieve_tokens(self, authorization_response: str, **kwargs):
Expand All @@ -90,7 +99,7 @@ def retrieve_tokens(self, authorization_response: str, **kwargs):
client_secret=self.config.client_secret.get_secret_value(),
redirect_uri=self.redirect_uri,
)
saved_flow = self.secrets_keeper.load(self.auth_flow_id)
saved_flow = self._secrets_keeper().load(self.auth_flow_id)
if saved_flow is None:
raise AuthFlowNotFound()
assert JsonWrapper.loads(saved_flow["state"])["token"] == JsonWrapper.loads(url_params["state"][0])["token"]
Expand All @@ -102,7 +111,7 @@ def retrieve_tokens(self, authorization_response: str, **kwargs):
client_secret=self.config.client_secret.get_secret_value(),
**kwargs,
)
self.secrets_keeper.save(self.auth_flow_id, token)
self._secrets_keeper().save(self.auth_flow_id, token)

# Deprecated
def get_access_token(self) -> str:
Expand All @@ -111,7 +120,7 @@ def get_access_token(self) -> str:
instance_url parameters are return by service, better to use it
new method get_access_data return all information to connect (secret and instance_url)
"""
token = self.secrets_keeper.load(self.auth_flow_id)
token = self._secrets_keeper().load(self.auth_flow_id)

if "expires_at" in token:
expires_at = token["expires_at"]
Expand All @@ -130,16 +139,16 @@ def get_access_token(self) -> str:
client_secret=self.config.client_secret.get_secret_value(),
)
new_token = client.refresh_token(self.token_url, refresh_token=token["refresh_token"])
self.secrets_keeper.save(self.auth_flow_id, new_token)
self._secrets_keeper().save(self.auth_flow_id, new_token)

return self.secrets_keeper.load(self.auth_flow_id)["access_token"]
return self._secrets_keeper().load(self.auth_flow_id)["access_token"]

def get_access_data(self):
"""
Returns the access_token to use to access resources
If necessary, this token will be refreshed
"""
access_data = self.secrets_keeper.load(self.auth_flow_id)
access_data = self._secrets_keeper().load(self.auth_flow_id)

logging.getLogger(__name__).debug("Refresh and get access data")

Expand All @@ -155,8 +164,8 @@ def get_access_data(self):
connection_data = client.refresh_token(self.token_url, refresh_token=access_data["refresh_token"])
logging.getLogger(__name__).debug(f"Refresh and get access data new token {str(connection_data)}")

self.secrets_keeper.save(self.auth_flow_id, connection_data)
secrets = self.secrets_keeper.load(self.auth_flow_id)
self._secrets_keeper().save(self.auth_flow_id, connection_data)
secrets = self._secrets_keeper().load(self.auth_flow_id)

logging.getLogger(__name__).debug("Refresh and get data finished")
return secrets
Expand All @@ -165,7 +174,7 @@ def get_refresh_token(self) -> str:
"""
Return the refresh token, used to obtain an access token
"""
return self.secrets_keeper.load(self.auth_flow_id)["refresh_token"]
return self._secrets_keeper().load(self.auth_flow_id)["refresh_token"]


class NoOAuth2RefreshToken(Exception):
Expand Down

0 comments on commit 4313dc4

Please sign in to comment.