From 01b4ebb394d616a3532fd30ee82d67f43cc16481 Mon Sep 17 00:00:00 2001 From: Danylo Korostil Date: Sun, 24 Aug 2025 15:43:45 +0300 Subject: [PATCH 1/2] feat(api): dataset API uuid support replaced Raw with Union for uuid_or_id Missing import added marshmallow-union marshmallow union --- docs/static/resources/openapi.json | 7 +++- pyproject.toml | 3 +- requirements/base.txt | 7 +++- requirements/development.txt | 8 ++++ superset/charts/schemas.py | 6 ++- superset/common/query_context_factory.py | 2 +- superset/common/query_object_factory.py | 2 +- superset/daos/base.py | 29 ++++++++++++++ superset/daos/datasource.py | 31 +++++++++++---- superset/daos/exceptions.py | 5 +++ superset/datasets/api.py | 38 +++++++++---------- superset/utils/core.py | 2 +- superset/views/datasource/utils.py | 2 +- .../integration_tests/query_context_tests.py | 2 +- tests/unit_tests/datasource/dao_tests.py | 10 ++--- 15 files changed, 110 insertions(+), 44 deletions(-) diff --git a/docs/static/resources/openapi.json b/docs/static/resources/openapi.json index bb8858fb6ccf..831b120bf05b 100644 --- a/docs/static/resources/openapi.json +++ b/docs/static/resources/openapi.json @@ -962,8 +962,11 @@ "ChartDataDatasource": { "properties": { "id": { - "description": "Datasource id", - "type": "integer" + "description": "Datasource id/uuid", + "oneOf": [ + { "type": "integer" }, + { "type": "string" } + ] }, "type": { "description": "Datasource type", diff --git a/pyproject.toml b/pyproject.toml index c0673c694681..9160340df630 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,7 @@ dependencies = [ "markdown>=3.0", # marshmallow>=4 has issues: https://github.com/apache/superset/issues/33162 "marshmallow>=3.0, <4", + "marshmallow-union>=0.1", "msgpack>=1.0.0, <1.1", "nh3>=0.2.11, <0.3", "numpy>1.23.5, <2.3", @@ -222,7 +223,7 @@ combine_as_imports = true include_trailing_comma = true line_length = 88 known_first_party = "superset" -known_third_party = "alembic, apispec, backoff, celery, click, colorama, cron_descriptor, croniter, cryptography, dateutil, deprecation, flask, flask_appbuilder, flask_babel, flask_caching, flask_compress, flask_jwt_extended, flask_login, flask_migrate, flask_sqlalchemy, flask_talisman, flask_testing, flask_wtf, freezegun, geohash, geopy, holidays, humanize, isodate, jinja2, jwt, markdown, markupsafe, marshmallow, msgpack, nh3, numpy, pandas, parameterized, parsedatetime, pgsanity, polyline, prison, progress, pyarrow, sqlalchemy_bigquery, pyhive, pyparsing, pytest, pytest_mock, pytz, redis, requests, selenium, setuptools, shillelagh, simplejson, slack, sqlalchemy, sqlalchemy_utils, typing_extensions, urllib3, werkzeug, wtforms, wtforms_json, yaml" +known_third_party = "alembic, apispec, backoff, celery, click, colorama, cron_descriptor, croniter, cryptography, dateutil, deprecation, flask, flask_appbuilder, flask_babel, flask_caching, flask_compress, flask_jwt_extended, flask_login, flask_migrate, flask_sqlalchemy, flask_talisman, flask_testing, flask_wtf, freezegun, geohash, geopy, holidays, humanize, isodate, jinja2, jwt, markdown, markupsafe, marshmallow, marshmallow-union, msgpack, nh3, numpy, pandas, parameterized, parsedatetime, pgsanity, polyline, prison, progress, pyarrow, sqlalchemy_bigquery, pyhive, pyparsing, pytest, pytest_mock, pytz, redis, requests, selenium, setuptools, shillelagh, simplejson, slack, sqlalchemy, sqlalchemy_utils, typing_extensions, urllib3, werkzeug, wtforms, wtforms_json, yaml" multi_line_output = 3 order_by_type = false diff --git 
a/requirements/base.txt b/requirements/base.txt index c64d081adddb..7cbe444bfa2f 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -154,6 +154,7 @@ greenlet==3.1.1 # via # apache-superset (pyproject.toml) # shillelagh + # sqlalchemy gunicorn==23.0.0 # via apache-superset (pyproject.toml) h11==0.16.0 @@ -215,10 +216,13 @@ marshmallow==3.26.1 # apache-superset (pyproject.toml) # flask-appbuilder # marshmallow-sqlalchemy + # marshmallow-union marshmallow-sqlalchemy==1.4.0 # via # -r requirements/base.in # flask-appbuilder +marshmallow-union==0.1.15 + # via apache-superset (pyproject.toml) mdurl==0.1.2 # via markdown-it-py msgpack==1.0.8 @@ -235,6 +239,7 @@ numpy==1.26.4 # bottleneck # numexpr # pandas + # pyarrow odfpy==1.4.1 # via pandas openapi-schema-validator==0.6.3 @@ -267,7 +272,7 @@ parsedatetime==2.6 pgsanity==0.2.9 # via apache-superset (pyproject.toml) pillow==11.3.0 - # via apache_superset (pyproject.toml) + # via apache-superset (pyproject.toml) platformdirs==4.3.8 # via requests-cache ply==3.11 diff --git a/requirements/development.txt b/requirements/development.txt index b210597e7cbb..1fbbfbc1d7a8 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -313,6 +313,7 @@ greenlet==3.1.1 # apache-superset # gevent # shillelagh + # sqlalchemy grpcio==1.71.0 # via # apache-superset @@ -429,10 +430,15 @@ marshmallow==3.26.1 # apache-superset # flask-appbuilder # marshmallow-sqlalchemy + # marshmallow-union marshmallow-sqlalchemy==1.4.0 # via # -c requirements/base.txt # flask-appbuilder +marshmallow-union==0.1.15 + # via + # -c requirements/base.txt + # apache-superset matplotlib==3.9.0 # via prophet mccabe==0.7.0 @@ -469,6 +475,7 @@ numpy==1.26.4 # pandas # pandas-gbq # prophet + # pyarrow oauthlib==3.2.2 # via requests-oauthlib odfpy==1.4.1 @@ -539,6 +546,7 @@ pgsanity==0.2.9 # apache-superset pillow==11.3.0 # via + # -c requirements/base.txt # apache-superset # matplotlib pip==25.1.1 diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 3a05c3ff1767..e794f1809109 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -24,6 +24,7 @@ from flask_babel import gettext as _ from marshmallow import EXCLUDE, fields, post_load, Schema, validate from marshmallow.validate import Length, Range +from marshmallow_union import Union from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.db_engine_specs.base import builtin_time_grains @@ -1130,8 +1131,9 @@ class AnnotationLayerSchema(Schema): class ChartDataDatasourceSchema(Schema): description = "Chart datasource" - id = fields.Integer( - metadata={"description": "Datasource id"}, + id = Union( + [fields.Integer(), fields.UUID()], + metadata={"description": "Datasource id or uuid"}, required=True, ) type = fields.String( diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index 9205ba418efb..81f460c08b4c 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -106,7 +106,7 @@ def create( # pylint: disable=too-many-arguments def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource: return DatasourceDAO.get_datasource( datasource_type=DatasourceType(datasource["type"]), - datasource_id=int(datasource["id"]), + database_id_or_uuid=datasource["id"], ) def _get_slice(self, slice_id: Any) -> Slice | None: diff --git a/superset/common/query_object_factory.py b/superset/common/query_object_factory.py index 
063ba023481e..31d2843da56b 100644 --- a/superset/common/query_object_factory.py +++ b/superset/common/query_object_factory.py @@ -91,7 +91,7 @@ def create( # pylint: disable=too-many-arguments def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource: return self._datasource_dao.get_datasource( datasource_type=DatasourceType(datasource["type"]), - datasource_id=int(datasource["id"]), + database_id_or_uuid=datasource["id"], ) def _process_extras( diff --git a/superset/daos/base.py b/superset/daos/base.py index e393034062b8..79adeb6f548f 100644 --- a/superset/daos/base.py +++ b/superset/daos/base.py @@ -43,12 +43,41 @@ class BaseDAO(Generic[T]): Child classes can register base filtering to be applied to all filter methods """ id_column_name = "id" + uuid_column_name = "uuid" def __init_subclass__(cls) -> None: cls.model_cls = get_args( cls.__orig_bases__[0] # type: ignore # pylint: disable=no-member )[0] + @classmethod + def find_by_id_or_uuid( + cls, + model_id_or_uuid: str, + skip_base_filter: bool = False, + ) -> T | None: + """ + Find a model by id, if defined applies `base_filter` + """ + query = db.session.query(cls.model_cls) + if cls.base_filter and not skip_base_filter: + data_model = SQLAInterface(cls.model_cls, db.session) + query = cls.base_filter( # pylint: disable=not-callable + cls.id_column_name, data_model + ).apply(query, None) + id_column = getattr(cls.model_cls, cls.id_column_name) + uuid_column = getattr(cls.model_cls, cls.uuid_column_name) + + if model_id_or_uuid.isdigit(): + filter = id_column == int(model_id_or_uuid) + else: + filter = uuid_column == model_id_or_uuid + try: + return query.filter(filter).one_or_none() + except StatementError: + # can happen if int is passed instead of a string or similar + return None + @classmethod def find_by_id( cls, diff --git a/superset/daos/datasource.py b/superset/daos/datasource.py index 9a48bb86b7b6..c754fd8e1f5f 100644 --- a/superset/daos/datasource.py +++ b/superset/daos/datasource.py @@ -16,12 +16,17 @@ # under the License. 
import logging +import uuid from typing import Union from superset import db from superset.connectors.sqla.models import SqlaTable from superset.daos.base import BaseDAO -from superset.daos.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError +from superset.daos.exceptions import ( + DatasourceNotFound, + DatasourceTypeNotSupportedError, + DatasourceValueIsIncorrect, +) from superset.models.sql_lab import Query, SavedQuery from superset.utils.core import DatasourceType @@ -41,22 +46,34 @@ class DatasourceDAO(BaseDAO[Datasource]): def get_datasource( cls, datasource_type: Union[DatasourceType, str], - datasource_id: int, + database_id_or_uuid: int | str, ) -> Datasource: if datasource_type not in cls.sources: raise DatasourceTypeNotSupportedError() + model = cls.sources[datasource_type] + + if str(database_id_or_uuid).isdigit(): + filter = model.id == int(database_id_or_uuid) + else: + try: + uuid.UUID(str(database_id_or_uuid)) # uuid validation + filter = model.uuid == database_id_or_uuid + except ValueError as err: + logger.warning( + "database_id_or_uuid %s is not a valid uuid", database_id_or_uuid + ) + raise DatasourceValueIsIncorrect() from err + datasource = ( - db.session.query(cls.sources[datasource_type]) - .filter_by(id=datasource_id) - .one_or_none() + db.session.query(cls.sources[datasource_type]).filter(filter).one_or_none() ) if not datasource: logger.warning( - "Datasource not found datasource_type: %s, datasource_id: %s", + "Datasource not found datasource_type: %s, database_id_or_uuid: %s", datasource_type, - datasource_id, + database_id_or_uuid, ) raise DatasourceNotFound() diff --git a/superset/daos/exceptions.py b/superset/daos/exceptions.py index ebd20fee631a..82bff417e0dd 100644 --- a/superset/daos/exceptions.py +++ b/superset/daos/exceptions.py @@ -35,3 +35,8 @@ class DatasourceTypeNotSupportedError(DAOException): class DatasourceNotFound(DAOException): status = 404 message = "Datasource does not exist" + + +class DatasourceValueIsIncorrect(DAOException): + status = 422 + message = "Datasource value is neither an id nor a uuid" diff --git a/superset/datasets/api.py b/superset/datasets/api.py index 2afd0411c281..6821b837e3cc 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -702,7 +702,7 @@ def refresh(self, pk: int) -> Response: ) return self.response_422(message=str(ex)) - @expose("/<pk>/related_objects", methods=("GET",)) + @expose("/<id_or_uuid>/related_objects", methods=("GET",)) @protect() @safe @statsd_metrics @@ -711,16 +711,16 @@ def refresh(self, pk: int) -> Response: f".related_objects", log_to_statsd=False, ) - def related_objects(self, pk: int) -> Response: + def related_objects(self, id_or_uuid: str) -> Response: """Get charts and dashboards count associated to a dataset.
--- get: summary: Get charts and dashboards count associated to a dataset parameters: - in: path - name: pk + name: id_or_uuid schema: - type: integer + type: string responses: 200: 200: @@ -736,10 +736,10 @@ def related_objects(self, pk: int) -> Response: 500: $ref: '#/components/responses/500' """ - dataset = DatasetDAO.find_by_id(pk) + dataset = DatasetDAO.find_by_id_or_uuid(id_or_uuid) if not dataset: return self.response_404() - data = DatasetDAO.get_related_objects(pk) + data = DatasetDAO.get_related_objects(dataset.id) charts = [ { "id": chart.id, @@ -1081,7 +1081,7 @@ def warm_up_cache(self) -> Response: except CommandException as ex: return self.response(ex.status, message=ex.message) - @expose("/<pk>", methods=("GET",)) + @expose("/<id_or_uuid>", methods=("GET",)) @protect() @safe @rison(get_item_schema) @@ -1091,7 +1091,7 @@ def warm_up_cache(self) -> Response: action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.get", log_to_statsd=False, ) - def get(self, pk: int, **kwargs: Any) -> Response: + def get(self, id_or_uuid: str, **kwargs: Any) -> Response: """Get a dataset. --- get: @@ -1100,9 +1100,9 @@ def get(self, pk: int, **kwargs: Any) -> Response: parameters: - in: path schema: - type: integer - description: The dataset ID - name: pk + type: string + description: Either the id of the dataset or its uuid + name: id_or_uuid - in: query name: q content: @@ -1138,13 +1138,8 @@ def get(self, pk: int, **kwargs: Any) -> Response: 500: $ref: '#/components/responses/500' """ - item: SqlaTable | None = self.datamodel.get( - pk, - self._base_filters, - self.show_select_columns, - self.show_outer_default_load, - ) - if not item: + table = DatasetDAO.find_by_id_or_uuid(id_or_uuid) + if not table: return self.response_404() response: dict[str, Any] = {} @@ -1162,8 +1157,8 @@ def get(self, pk: int, **kwargs: Any) -> Response: else: show_model_schema = self.show_model_schema - response["id"] = pk - response[API_RESULT_RES_KEY] = show_model_schema.dump(item, many=False) + response["id"] = table.id + response[API_RESULT_RES_KEY] = show_model_schema.dump(table, many=False) # remove folders from resposne if `DATASET_FOLDERS` is disabled, so that it's # possible to inspect if the feature is supported or not @@ -1175,12 +1170,13 @@ def get(self, pk: int, **kwargs: Any) -> Response: if parse_boolean_string(request.args.get("include_rendered_sql")): try: - processor = get_template_processor(database=item.database) + processor = get_template_processor(database=table.database) response["result"] = self.render_dataset_fields( response["result"], processor ) except SupersetTemplateException as ex: return self.response_400(message=str(ex)) + return self.response(200, **response) @expose("/<pk>/drill_info/", methods=("GET",)) diff --git a/superset/utils/core.py b/superset/utils/core.py index 81a4a7078d48..70b76d4118b0 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -209,7 +209,7 @@ class HeaderDataType(TypedDict): class DatasourceDict(TypedDict): type: str # todo(hugh): update this to be DatasourceType - id: int + id: int | str class AdhocFilterClause(TypedDict, total=False): diff --git a/superset/views/datasource/utils.py b/superset/views/datasource/utils.py index 75c9eb7b0d3a..e0751bb19ab7 100644 --- a/superset/views/datasource/utils.py +++ b/superset/views/datasource/utils.py @@ -54,7 +54,7 @@ def get_samples( # pylint: disable=too-many-arguments ) -> dict[str, Any]: datasource = DatasourceDAO.get_datasource( datasource_type=datasource_type, - datasource_id=datasource_id, +
database_id_or_uuid=str(datasource_id), ) limit_clause = get_limit_clause(page, per_page) diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index c57ea11419d0..751f96631ead 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -153,7 +153,7 @@ def test_query_cache_key_changes_when_datasource_is_updated(self): # make temporary change and revert it to refresh the changed_on property datasource = DatasourceDAO.get_datasource( datasource_type=DatasourceType(payload["datasource"]["type"]), - datasource_id=payload["datasource"]["id"], + database_id_or_uuid=payload["datasource"]["id"], ) description_original = datasource.description datasource.description = "temporary description" diff --git a/tests/unit_tests/datasource/dao_tests.py b/tests/unit_tests/datasource/dao_tests.py index bf486010abf6..17e170f12b07 100644 --- a/tests/unit_tests/datasource/dao_tests.py +++ b/tests/unit_tests/datasource/dao_tests.py @@ -76,7 +76,7 @@ def test_get_datasource_sqlatable(session_with_data: Session) -> None: result = DatasourceDAO.get_datasource( datasource_type=DatasourceType.TABLE, - datasource_id=1, + database_id_or_uuid=1, ) assert 1 == result.id @@ -89,7 +89,7 @@ def test_get_datasource_query(session_with_data: Session) -> None: from superset.models.sql_lab import Query result = DatasourceDAO.get_datasource( - datasource_type=DatasourceType.QUERY, datasource_id=1 + datasource_type=DatasourceType.QUERY, database_id_or_uuid=1 ) assert result.id == 1 @@ -102,7 +102,7 @@ def test_get_datasource_saved_query(session_with_data: Session) -> None: result = DatasourceDAO.get_datasource( datasource_type=DatasourceType.SAVEDQUERY, - datasource_id=1, + database_id_or_uuid=1, ) assert result.id == 1 @@ -116,7 +116,7 @@ def test_get_datasource_w_str_param(session_with_data: Session) -> None: assert isinstance( DatasourceDAO.get_datasource( datasource_type="table", - datasource_id=1, + database_id_or_uuid=1, ), SqlaTable, ) @@ -136,5 +136,5 @@ def test_not_found_datasource(session_with_data: Session) -> None: with pytest.raises(DatasourceNotFound): DatasourceDAO.get_datasource( datasource_type="table", - datasource_id=500000, + database_id_or_uuid=500000, ) From 24c81433c5c679b0b66272c6c33ec4b3d189940d Mon Sep 17 00:00:00 2001 From: Danylo Korostil Date: Thu, 28 Aug 2025 12:19:08 +0300 Subject: [PATCH 2/2] Comments fix --- superset/daos/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/superset/daos/base.py b/superset/daos/base.py index e97bafd66d8c..2b40dc0e3334 100644 --- a/superset/daos/base.py +++ b/superset/daos/base.py @@ -61,7 +61,7 @@ def find_by_id_or_uuid( skip_base_filter: bool = False, ) -> T | None: """ - Find a model by id, if defined applies `base_filter` + Find a model by id or uuid, if defined applies `base_filter` """ query = db.session.query(cls.model_cls) if cls.base_filter and not skip_base_filter: @@ -79,7 +79,7 @@ def find_by_id_or_uuid( try: return query.filter(filter).one_or_none() except StatementError: - # can happen if int is passed instead of a string or similar + # can happen if neither uuid nor int is passed return None @classmethod
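For reviewers, a minimal usage sketch of the behaviour this series adds, assuming a running Superset application context and existing records; the id and uuid values below are placeholders, not real rows:

from superset.daos.datasource import DatasourceDAO
from superset.utils.core import DatasourceType

# Digit-only values keep resolving against the primary-key column, as before.
by_id = DatasourceDAO.get_datasource(
    datasource_type=DatasourceType.TABLE,
    database_id_or_uuid="42",  # placeholder id
)

# Values that parse as a UUID are matched against the model's uuid column instead.
by_uuid = DatasourceDAO.get_datasource(
    datasource_type=DatasourceType.TABLE,
    database_id_or_uuid="1b79f1e6-5c3f-4a2b-9d5e-8e5a1c2d3f40",  # placeholder uuid
)

# Any other value raises DatasourceValueIsIncorrect, which the API surfaces as HTTP 422.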