diff --git a/airflow/models/connection.py b/airflow/models/connection.py index 373c10f3266ea..1e835b4673dab 100644 --- a/airflow/models/connection.py +++ b/airflow/models/connection.py @@ -32,6 +32,7 @@ from airflow.models.base import ID_LEN, Base from airflow.models.crypto import get_fernet from airflow.secrets.cache import SecretCache +from airflow.utils.helpers import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.log.secrets_masker import mask_secret from airflow.utils.module_loading import import_string @@ -480,6 +481,34 @@ def get_connection_from_secrets(cls, conn_id: str) -> Connection: def to_dict(self) -> dict[str, Any]: return {"conn_id": self.conn_id, "description": self.description, "uri": self.get_uri()} + def to_json_dict(self, *, prune_empty: bool = False, validate: bool = True) -> dict[str, Any]: + """ + Convert Connection to json-serializable dictionary. + + :param prune_empty: Whether or not remove empty values. + :param validate: Validate dictionary is JSON-serializable + + :meta private: + """ + conn = { + "conn_id": self.conn_id, + "conn_type": self.conn_type, + "description": self.description, + "host": self.host, + "login": self.login, + "password": self.password, + "schema": self.schema, + "port": self.port, + } + if prune_empty: + conn = prune_dict(val=conn, mode="strict") + if (extra := self.extra_dejson) or not prune_empty: + conn["extra"] = extra + + if validate: + json.dumps(conn) + return conn + @classmethod def from_json(cls, value, conn_id=None) -> Connection: kwargs = json.loads(value) @@ -496,3 +525,9 @@ def from_json(cls, value, conn_id=None) -> Connection: except ValueError: raise ValueError(f"Expected integer value for `port`, but got {port!r} instead.") return Connection(conn_id=conn_id, **kwargs) + + def as_json(self) -> str: + """Convert Connection to JSON-string object.""" + conn = self.to_json_dict(prune_empty=True, validate=False) + conn.pop("conn_id", None) + return json.dumps(conn) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index d2a72cb1ca296..b5935b14a910b 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -498,7 +498,7 @@ def serialize( type_=DAT.SIMPLE_TASK_INSTANCE, ) elif isinstance(var, Connection): - return cls._encode(var.to_dict(), type_=DAT.CONNECTION) + return cls._encode(var.to_json_dict(validate=True), type_=DAT.CONNECTION) elif use_pydantic_models and _ENABLE_AIP_44: def _pydantic_model_dump(model_cls: type[BaseModel], var: Any) -> dict[str, Any]: diff --git a/docs/apache-airflow/howto/connection.rst b/docs/apache-airflow/howto/connection.rst index 2c07b03b6d396..f90f7e96b10e5 100644 --- a/docs/apache-airflow/howto/connection.rst +++ b/docs/apache-airflow/howto/connection.rst @@ -66,6 +66,43 @@ If serializing with JSON: } }' +Generating a JSON connection representation +""""""""""""""""""""""""""""""""""""""""""" + +.. versionadded:: 2.8.0 + + +To make connection JSON generation easier, the :py:class:`~airflow.models.connection.Connection` class has a +convenience property :py:meth:`~airflow.models.connection.Connection.as_json`. It can be used like so: + +.. code-block:: pycon + + >>> from airflow.models.connection import Connection + >>> c = Connection( + ... conn_id="some_conn", + ... conn_type="mysql", + ... description="connection description", + ... host="myhost.com", + ... login="myname", + ... password="mypassword", + ... extra={"this_param": "some val", "that_param": "other val*"}, + ... ) + >>> print(f"AIRFLOW_CONN_{c.conn_id.upper()}='{c.as_json()}'") + AIRFLOW_CONN_SOME_CONN='{"conn_type": "mysql", "description": "connection description", "host": "myhost.com", "login": "myname", "password": "mypassword", "extra": {"this_param": "some val", "that_param": "other val*"}}' + +In addition, same approach could be used to convert Connection from URI format to JSON format + +.. code-block:: pycon + + >>> from airflow.models.connection import Connection + >>> c = Connection( + ... conn_id="awesome_conn", + ... description="Example Connection", + ... uri="aws://AKIAIOSFODNN7EXAMPLE:wJalrXUtnFEMI%2FK7MDENG%2FbPxRfiCYEXAMPLEKEY@/?__extra__=%7B%22region_name%22%3A+%22eu-central-1%22%2C+%22config_kwargs%22%3A+%7B%22retries%22%3A+%7B%22mode%22%3A+%22standard%22%2C+%22max_attempts%22%3A+10%7D%7D%7D", + ... ) + >>> print(f"AIRFLOW_CONN_{c.conn_id.upper()}='{c.as_json()}'") + AIRFLOW_CONN_AWESOME_CONN='{"conn_type": "aws", "description": "Example Connection", "host": "", "login": "AKIAIOSFODNN7EXAMPLE", "password": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", "schema": "", "extra": {"region_name": "eu-central-1", "config_kwargs": {"retries": {"mode": "standard", "max_attempts": 10}}}}' + URI format example ^^^^^^^^^^^^^^^^^^ diff --git a/tests/always/test_connection.py b/tests/always/test_connection.py index 9b0df4ea891d3..e01907a8d3365 100644 --- a/tests/always/test_connection.py +++ b/tests/always/test_connection.py @@ -790,3 +790,59 @@ def test_get_uri_no_conn_type(self): assert Connection(host="abc").get_uri() == "//abc" # parsing back as conn still works assert Connection(uri="//abc").host == "abc" + + @pytest.mark.parametrize( + "conn, expected_json", + [ + pytest.param(Connection(), "{}", id="empty"), + pytest.param(Connection(host="apache.org", extra={}), '{"host": "apache.org"}', id="empty-extra"), + pytest.param( + Connection(conn_type="foo", login="", password="p@$$"), + '{"conn_type": "foo", "login": "", "password": "p@$$"}', + id="some-fields", + ), + pytest.param( + Connection( + conn_type="bar", + description="Sample Description", + host="example.org", + login="user", + password="p@$$", + schema="schema", + port=777, + extra={"foo": "bar", "answer": 42}, + ), + json.dumps( + { + "conn_type": "bar", + "description": "Sample Description", + "host": "example.org", + "login": "user", + "password": "p@$$", + "schema": "schema", + "port": 777, + "extra": {"foo": "bar", "answer": 42}, + } + ), + id="all-fields", + ), + pytest.param( + Connection(uri="aws://"), + # During parsing URI some of the fields evaluated as an empty strings + '{"conn_type": "aws", "host": "", "schema": ""}', + id="uri", + ), + ], + ) + def test_as_json_from_connection(self, conn: Connection, expected_json): + result = conn.as_json() + assert result == expected_json + restored_conn = Connection.from_json(result) + + assert restored_conn.conn_type == conn.conn_type + assert restored_conn.description == conn.description + assert restored_conn.host == conn.host + assert restored_conn.password == conn.password + assert restored_conn.schema == conn.schema + assert restored_conn.port == conn.port + assert restored_conn.extra_dejson == conn.extra_dejson diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index e05f69114c6da..0af29e8ebc67e 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -36,7 +36,7 @@ from airflow.models.xcom_arg import XComArg from airflow.operators.empty import EmptyOperator from airflow.operators.python import PythonOperator -from airflow.serialization.enums import DagAttributeTypes as DAT +from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.pydantic.dag import DagModelPydantic from airflow.serialization.pydantic.dag_run import DagRunPydantic from airflow.serialization.pydantic.job import JobPydantic @@ -213,6 +213,27 @@ def test_serialize_deserialize(input, encoded_type, cmp_func): json.dumps(serialized) # does not raise +@pytest.mark.parametrize( + "conn_uri", + [ + pytest.param("aws://", id="only-conn-type"), + pytest.param("postgres://username:password@ec2.compute.com:5432/the_database", id="all-non-extra"), + pytest.param( + "///?__extra__=%7B%22foo%22%3A+%22bar%22%2C+%22answer%22%3A+42%2C+%22" + "nullable%22%3A+null%2C+%22empty%22%3A+%22%22%2C+%22zero%22%3A+0%7D", + id="extra", + ), + ], +) +def test_backcompat_deserialize_connection(conn_uri): + """Test deserialize connection which serialised by previous serializer implementation.""" + from airflow.serialization.serialized_objects import BaseSerialization + + conn_obj = {Encoding.TYPE: DAT.CONNECTION, Encoding.VAR: {"conn_id": "TEST_ID", "uri": conn_uri}} + deserialized = BaseSerialization.deserialize(conn_obj) + assert deserialized.get_uri() == conn_uri + + @pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled") @pytest.mark.parametrize( "input, pydantic_class, encoded_type, cmp_func",