diff --git a/superset/models/slice.py b/superset/models/slice.py index f47f424fc3e2..f6379f0e77d9 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -363,7 +363,7 @@ def get_query_context_factory(self) -> QueryContextFactory: @classmethod def get(cls, id_or_uuid: str) -> Slice: - qry = db.session.query(Slice).filter_by(id_or_uuid_filter(id_or_uuid)) + qry = db.session.query(Slice).filter(id_or_uuid_filter(id_or_uuid)) return qry.one_or_none() diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index 1b95ed1f6398..381e3f684606 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import uuid from io import BytesIO from unittest import mock from unittest.mock import patch @@ -1076,6 +1077,43 @@ def test_get_chart_not_found(self): rv = self.get_assert_metric(uri, "get") assert rv.status_code == 404 + @parameterized.expand( + [ + ("by_id", lambda chart: str(chart.id), "id"), + ( + "by_uuid", + lambda chart: str(chart.uuid) if chart.uuid else pytest.skip("No UUID"), + "uuid", + ), + ] + ) + def test_slice_get_existing(self, test_name, get_identifier, field_type): + """Test Slice.get() successfully retrieves existing charts.""" + admin = self.get_user("admin") + chart = self.insert_chart(f"test_slice_get_{field_type}", [admin.id], 1) + + identifier = get_identifier(chart) + result = Slice.get(identifier) + + assert result is not None + assert result.id == chart.id + if field_type == "uuid" and chart.uuid: + assert result.uuid == chart.uuid + + db.session.delete(chart) + db.session.commit() + + @parameterized.expand( + [ + ("nonexistent_id", "999999"), + ("nonexistent_uuid", str(uuid.uuid4())), + ] + ) + def test_slice_get_not_found(self, test_name, identifier): + """Test Slice.get() returns None for non-existent identifiers.""" + result = Slice.get(identifier) + assert result is None + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_get_chart_no_data_access(self): """ diff --git a/tests/unit_tests/models/slice_test.py b/tests/unit_tests/models/slice_test.py new file mode 100644 index 000000000000..3b9f0666c958 --- /dev/null +++ b/tests/unit_tests/models/slice_test.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from parameterized import parameterized + +from superset.models.slice import id_or_uuid_filter, Slice + + +class TestSlice: + """Test cases for Slice model functionality.""" + + @parameterized.expand( + [ + ("numeric_id", "123"), + ("uuid_string", "550e8400-e29b-41d4-a716-446655440000"), + ] + ) + def test_slice_get_calls_filter_correctly(self, test_name, id_or_uuid): + """Test Slice.get() calls filter() correctly for ID and UUID.""" + with patch("superset.models.slice.db") as mock_db: + # Setup mock chain + mock_query = MagicMock() + mock_filtered_query = MagicMock() + mock_db.session.query.return_value = mock_query + mock_query.filter.return_value = mock_filtered_query + mock_filtered_query.one_or_none.return_value = None + + # Call the method + result = Slice.get(id_or_uuid) + + # Verify correct methods called + mock_db.session.query.assert_called_once_with(Slice) + mock_query.filter.assert_called_once() # Not filter_by! + mock_filtered_query.one_or_none.assert_called_once() + assert result is None + + @parameterized.expand( + [ + ("numeric_id", "123"), + ("large_id", "999999"), + ("uuid_string", str(uuid.uuid4())), + ] + ) + def test_slice_get_no_type_error(self, test_name, input_value): + """Verify Slice.get() doesn't raise TypeError for various inputs.""" + try: + result = Slice.get(input_value) + # Success - no TypeError, result can be None or a Slice + assert result is None or hasattr(result, "id") + except TypeError as e: + if "filter_by() takes 1 positional argument" in str(e): + pytest.fail( + f"filter_by() bug exists: Slice.get('{input_value}') failed with {e}" # noqa: E501 + ) + else: + raise + + @parameterized.expand( + [ + ("numeric_id", "123"), + ("uuid_format", "550e8400-e29b-41d4-a716-446655440000"), + ("invalid_string", "not-a-number"), + ] + ) + def test_id_or_uuid_filter(self, test_name, input_value): + """Test id_or_uuid_filter returns correct BinaryExpression.""" + result = id_or_uuid_filter(input_value) + assert result is not None