Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 23 additions & 23 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
oauth2_scope = ""
oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name
oauth2_token_request_uri: str | None = None
oauth2_token_request_type = "data"

# Driver-specific exception that should be mapped to OAuth2RedirectError
oauth2_exception = OAuth2RedirectError
Expand Down Expand Up @@ -525,6 +526,9 @@ def get_oauth2_config(cls) -> OAuth2ClientConfig | None:
"token_request_uri",
cls.oauth2_token_request_uri,
),
"request_content_type": db_engine_spec_config.get(
"request_content_type", cls.oauth2_token_request_type
),
}

return config
Expand Down Expand Up @@ -562,18 +566,16 @@ def get_oauth2_token(
"""
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
response = requests.post(
uri,
json={
"code": code,
"client_id": config["id"],
"client_secret": config["secret"],
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
},
timeout=timeout,
)
return response.json()
req_body = {
"code": code,
"client_id": config["id"],
"client_secret": config["secret"],
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
}
if config["request_content_type"] == "data":
return requests.post(uri, data=req_body, timeout=timeout).json()
return requests.post(uri, json=req_body, timeout=timeout).json()
Comment on lines 576 to 578
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

I think that the most common (standard?) workflow is to send form encoded data, and not JSON. It's just that the first OAuth2 implementation done for Superset targeted GSheets, and Google uses JSON instead of form encoded (it's the same for BigQuery, for example). But now changing the default would break existing databases, so we need to leave JSON as the default.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that data is likely the more common implementation. As we're still in early days, I feel it's well motivated to introduce a breaking change, as long as we have an UPDATING.md explaining the required steps to get it to work again. This in the interest of having a clean/idiomatic API in the long term..

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@villebro yeah, that sounds good to me.


@classmethod
def get_oauth2_fresh_token(
Expand All @@ -586,17 +588,15 @@ def get_oauth2_fresh_token(
"""
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
response = requests.post(
uri,
json={
"client_id": config["id"],
"client_secret": config["secret"],
"refresh_token": refresh_token,
"grant_type": "refresh_token",
},
timeout=timeout,
)
return response.json()
req_body = {
"client_id": config["id"],
"client_secret": config["secret"],
"refresh_token": refresh_token,
"grant_type": "refresh_token",
}
if config["request_content_type"] == "data":
return requests.post(uri, data=req_body, timeout=timeout).json()
return requests.post(uri, json=req_body, timeout=timeout).json()

@classmethod
def get_allows_alias_in_select(
Expand Down
27 changes: 26 additions & 1 deletion superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
import numpy as np
import pandas as pd
import pyarrow as pa
from flask import ctx, current_app, Flask, g
import requests
from flask import copy_current_request_context, ctx, current_app, Flask, g
from sqlalchemy import text
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import NoSuchTableError
from trino.exceptions import HttpError

from superset import db
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
Expand Down Expand Up @@ -60,11 +62,28 @@
logger = logging.getLogger(__name__)


class CustomTrinoAuthErrorMeta(type):
def __instancecheck__(cls, instance: object) -> bool:
logger.info("is this being called?")
return isinstance(
instance, HttpError
) and "error 401: b'Invalid credentials'" in str(instance)


class TrinoAuthError(HttpError, metaclass=CustomTrinoAuthErrorMeta):
pass


class TrinoEngineSpec(PrestoBaseEngineSpec):
engine = "trino"
engine_name = "Trino"
allows_alias_to_source_column = False

# OAuth 2.0
supports_oauth2 = True
oauth2_exception = TrinoAuthError
oauth2_token_request_type = "data"

@classmethod
def get_extra_table_metadata(
cls,
Expand Down Expand Up @@ -140,6 +159,10 @@ def update_impersonation_config(
# Set principal_username=$effective_username
if backend_name == "trino" and username is not None:
connect_args["user"] = username
if access_token is not None:
http_session = requests.Session()
http_session.headers.update({"Authorization": f"Bearer {access_token}"})
connect_args["http_session"] = http_session

@classmethod
def get_url_for_impersonation(
Expand All @@ -152,6 +175,7 @@ def get_url_for_impersonation(
"""
Return a modified URL with the username set.

:param access_token: Personal access token for OAuth2
:param url: SQLAlchemy URL object
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
Expand Down Expand Up @@ -226,6 +250,7 @@ def execute_with_cursor(
execute_result: dict[str, Any] = {}
execute_event = threading.Event()

@copy_current_request_context
def _execute(
results: dict[str, Any],
event: threading.Event,
Expand Down
4 changes: 4 additions & 0 deletions superset/superset_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ class OAuth2ClientConfig(TypedDict):
# expired access token.
token_request_uri: str

# Not all identity providers expect json. Keycloak expects a form encoded request,
# which in the `requests` package context means using the `data` param, not `json`.
request_content_type: str


class OAuth2TokenResponse(TypedDict, total=False):
"""
Expand Down
7 changes: 6 additions & 1 deletion superset/utils/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import backoff
import jwt
from flask import current_app, url_for
from marshmallow import EXCLUDE, fields, post_load, Schema
from marshmallow import EXCLUDE, fields, post_load, Schema, validate

from superset import db
from superset.distributed_lock import KeyValueDistributedLock
Expand Down Expand Up @@ -192,3 +192,8 @@ class OAuth2ClientConfigSchema(Schema):
)
authorization_request_uri = fields.String(required=True)
token_request_uri = fields.String(required=True)
request_content_type = fields.String(
Copy link
Contributor

@fisjac fisjac Sep 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be mistaken, but I believe this validation is only applied to the client_info contained within encrypted_extra when provided by a user. Is the intent that this is where the request_content_type is going to be provided, or is it going to be set as a default for the Trino engine spec?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I think the idea is that DB engine specs can have a default content type (defined in the oauth2_token_request_type class attribute), but it can be overridden on a per-database basis by setting it in the encrypted_extra.

required=False,
load_default=lambda: "json",
validate=validate.OneOf(["json", "data"]),
)
1 change: 1 addition & 0 deletions tests/unit_tests/db_engine_specs/test_gsheets.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ def oauth2_config() -> OAuth2ClientConfig:
"redirect_uri": "http://localhost:8088/api/v1/oauth2/",
"authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth",
"token_request_uri": "https://oauth2.googleapis.com/token",
"request_content_type": "json",
}


Expand Down
117 changes: 90 additions & 27 deletions tests/unit_tests/db_engine_specs/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@
SupersetDBAPIProgrammingError,
)
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType, SQLAColumnType, SQLType
from superset.superset_typing import (
OAuth2ClientConfig,
ResultSetColumnType,
SQLAColumnType,
SQLType,
)
from superset.utils import json
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
Expand Down Expand Up @@ -421,21 +426,23 @@ def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
def _mock_execute(*args, **kwargs):
mock_cursor.query_id = query_id

mock_cursor.execute.side_effect = _mock_execute
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)
with app.test_request_context("/some/place/"):
mock_cursor.execute.side_effect = _mock_execute

mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
)
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)

mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
)


def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
Expand All @@ -446,23 +453,25 @@ def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
mock_cursor.query_id = None

mock_query = mocker.MagicMock()
g.some_value = "some_value"

def _mock_execute(*args, **kwargs):
assert has_app_context()
assert g.some_value == "some_value"

with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)
with app.test_request_context("/some/place/"):
g.some_value = "some_value"

with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)


def test_get_columns(mocker: MockerFixture):
Expand Down Expand Up @@ -784,3 +793,57 @@ def test_where_latest_partition(
)
== f"""SELECT * FROM table \nWHERE partition_key = {expected_value}"""
)


@pytest.fixture
def oauth2_config() -> OAuth2ClientConfig:
"""
Config for Trino OAuth2.
"""
return {
"id": "trino",
"secret": "very-secret",
"scope": "",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "https://trino.auth.server.example/realms/master/protocol/openid-connect/auth",
"token_request_uri": "https://trino.auth.server.example/master/protocol/openid-connect/token",
"request_content_type": "data",
}


def test_get_oauth2_token(
mocker: MockerFixture,
oauth2_config: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_token`.
"""
from superset.db_engine_specs.trino import TrinoEngineSpec

requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}

assert TrinoEngineSpec.get_oauth2_token(oauth2_config, "code") == {
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
requests.post.assert_called_with(
"https://trino.auth.server.example/master/protocol/openid-connect/token",
data={
"code": "code",
"client_id": "trino",
"client_secret": "very-secret",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"grant_type": "authorization_code",
},
timeout=30.0,
)
1 change: 1 addition & 0 deletions tests/unit_tests/models/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ def test_get_oauth2_config(app_context: None) -> None:
"token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
"scope": "refresh_token session:role:USERADMIN",
"redirect_uri": "http://example.com/api/v1/database/oauth2/",
"request_content_type": "json",
}


Expand Down
Loading