Skip to content

Commit

Permalink
AIP-72: Add support for fetching variables and connections in Supervisor
Browse files Browse the repository at this point in the history
- Updated `VariableOperations` and `ConnectionOperations` in `Client`:
  - Added `get` methods for fetching variable and connection details.
- Refactored communication protocol (`comms.py`):
  - Unified result models (`ConnectionResult`, `VariableResult`, `XComResult`) extending auto-generated models.
  - Renamed `ReadXCom` to `GetXCom` for consistency with other request models (`GetConnection`, `GetVariable`).
- Updated `WatchedSubprocess` in `supervisor.py`:
  - Integrated `handle_requests` to process `GetVariable` and `GetConnection` messages.

and some minor refactors
  • Loading branch information
kaxil committed Nov 21, 2024
1 parent a130711 commit 42ddd22
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 43 deletions.
27 changes: 23 additions & 4 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
TIHeartbeatInfo,
TITerminalStatePayload,
ValidationError as RemoteValidationError,
VariableResponse,
)
from airflow.utils.net import get_hostname
from airflow.utils.platform import getuser
Expand Down Expand Up @@ -124,17 +125,29 @@ def heartbeat(self, id: uuid.UUID, pid: int):


class ConnectionOperations:
__slots__ = ("client", "decoder")
__slots__ = ("client",)

def __init__(self, client: Client):
self.client = client

def get(self, id: str) -> ConnectionResponse:
def get(self, conn_id: str) -> ConnectionResponse:
"""Get a connection from the API server."""
resp = self.client.get(f"connection/{id}")
resp = self.client.get(f"connections/{conn_id}")
return ConnectionResponse.model_validate_json(resp.read())


class VariableOperations:
__slots__ = ("client",)

def __init__(self, client: Client):
self.client = client

def get(self, key: str) -> VariableResponse:
"""Get a variable from the API server."""
resp = self.client.get(f"variables/{key}")
return VariableResponse.model_validate_json(resp.read())


class BearerAuth(httpx.Auth):
def __init__(self, token: str):
self.token: str = token
Expand Down Expand Up @@ -186,9 +199,15 @@ def task_instances(self) -> TaskInstanceOperations:
@lru_cache() # type: ignore[misc]
@property
def connections(self) -> ConnectionOperations:
"""Operations related to TaskInstances."""
"""Operations related to Connections."""
return ConnectionOperations(self)

@lru_cache() # type: ignore[misc]
@property
def variables(self) -> VariableOperations:
"""Operations related to Variables."""
return VariableOperations(self)


# This is only used for parsing. ServerResponseError is raised instead
class _ErrorBody(BaseModel):
Expand Down
35 changes: 20 additions & 15 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,17 @@

from __future__ import annotations

from typing import Annotated, Any, Literal, Union
from typing import Annotated, Literal, Union

from pydantic import BaseModel, ConfigDict, Field

from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState # noqa: TCH001
from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
TaskInstance,
TerminalTIState,
VariableResponse,
XComResponse,
)


class StartupDetails(BaseModel):
Expand All @@ -64,23 +70,22 @@ class StartupDetails(BaseModel):
type: Literal["StartupDetails"] = "StartupDetails"


class XComResponse(BaseModel):
class XComResult(XComResponse):
"""Response to ReadXCom request."""

key: str
value: Any
type: Literal["XComResult"] = "XComResult"

type: Literal["XComResponse"] = "XComResponse"

class ConnectionResult(ConnectionResponse):
type: Literal["ConnectionResult"] = "ConnectionResult"

class ConnectionResponse(BaseModel):
conn: Any

type: Literal["ConnectionResponse"] = "ConnectionResponse"
class VariableResult(VariableResponse):
type: Literal["VariableResult"] = "VariableResult"


ToTask = Annotated[
Union[StartupDetails, XComResponse, ConnectionResponse],
Union[StartupDetails, XComResult, ConnectionResult, VariableResult],
Field(discriminator="type"),
]

Expand All @@ -98,22 +103,22 @@ class TaskState(BaseModel):
type: Literal["TaskState"] = "TaskState"


class ReadXCom(BaseModel):
class GetXCom(BaseModel):
key: str
type: Literal["ReadXCom"] = "ReadXCom"
type: Literal["GetXCom"] = "GetXCom"


class GetConnection(BaseModel):
id: str
conn_id: str
type: Literal["GetConnection"] = "GetConnection"


class GetVariable(BaseModel):
id: str
key: str
type: Literal["GetVariable"] = "GetVariable"


ToSupervisor = Annotated[
Union[TaskState, ReadXCom, GetConnection, GetVariable],
Union[TaskState, GetXCom, GetConnection, GetVariable],
Field(discriminator="type"),
]
26 changes: 9 additions & 17 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@
from airflow.sdk.api.client import Client
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import (
ConnectionResponse,
GetConnection,
GetVariable,
StartupDetails,
ToSupervisor,
)
Expand Down Expand Up @@ -480,10 +480,7 @@ def __repr__(self) -> str:
return rep + " >"

def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, None]:
encoder = ConnectionResponse.model_dump_json
# Use a buffer to avoid small allocations
buffer = bytearray(64)

"""Handle incoming requests from the task process, respond with the appropriate data."""
decoder = TypeAdapter[ToSupervisor](ToSupervisor)

while True:
Expand All @@ -495,28 +492,23 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
log.exception("Unable to decode message", line=line)
continue

# if isinstnace(msg, TaskState):
# if isinstance(msg, TaskState):
# self._terminal_state = msg.state
# elif isinstance(msg, ReadXCom):
# resp = XComResponse(key="secret", value=True)
# encoder.encode_into(resp, buffer)
# self.stdin.write(buffer + b"\n")
if isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.id)
resp = ConnectionResponse(conn=conn)
encoded_resp = encoder(resp)
buffer.extend(encoded_resp.encode())
conn = self.client.connections.get(msg.conn_id)
resp = conn.model_dump_json(exclude_unset=True).encode()
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
resp = var.model_dump_json(exclude_unset=True).encode()
else:
log.error("Unhandled request", msg=msg)
continue

buffer.extend(b"\n")
self.stdin.write(buffer)

# Ensure the buffer doesn't grow and stay large if a large payload is used. This won't grow it
# larger than it is, but it will shrink it
if len(buffer) > 1024:
buffer = buffer[:1024]
self.stdin.write(resp + b"\n")


# Sockets, even the `.makefile()` function don't correctly do line buffering on reading. If a chunk is read
Expand Down
11 changes: 10 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,16 @@ def send_request(self, log: Logger, msg: ToSupervisor):
self.request_socket.write(encoded_msg)


# This global variable will be used by Connection/Variable classes etc to send requests to
# This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution,
# to send requests back to the supervisor process.
#
# Why it needs to be a global:
# - Many parts of Airflow's codebase (e.g., connections, variables, and XComs) may rely on making dynamic requests
# to the parent process during task execution.
# - These calls occur in various locations and cannot easily pass the `CommsDecoder` instance through the
# deeply nested execution stack.
# - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily
# accessible wherever needed during task execution without modifying every layer of the call stack.
SUPERVISOR_COMMS: CommsDecoder

# State machine!
Expand Down
59 changes: 59 additions & 0 deletions task_sdk/tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytest

from airflow.sdk.api.client import Client, RemoteValidationError, ServerResponseError
from airflow.sdk.api.datamodels._generated import VariableResponse


class TestClient:
Expand Down Expand Up @@ -74,3 +75,61 @@ def handle_request(request: httpx.Request) -> httpx.Response:
client.get("http://error")
assert err.value.args == ("Not found",)
assert err.value.detail is None


def make_client(transport: httpx.MockTransport) -> Client:
"""Get a client with a custom transport"""
return Client(base_url="test://server", token="", transport=transport)


class TestVariableOperations:
"""
Test that the VariableOperations class works as expected. While the operations are simple, it
still catches the basic functionality of the client for variables including endpoint and
response parsing.
"""

def test_variable_get_success(self):
# Simulate a successful response from the server with a variable
def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == "/variables/test_key":
return httpx.Response(
status_code=200,
json={"key": "test_key", "value": "test_value"},
)
return httpx.Response(status_code=400, json={"detail": "Bad Request"})

client = make_client(transport=httpx.MockTransport(handle_request))
result = client.variables.get(key="test_key")

assert isinstance(result, VariableResponse)
assert result.key == "test_key"
assert result.value == "test_value"

def test_variable_not_found(self):
# Simulate a 404 response from the server
def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == "/variables/non_existent_var":
return httpx.Response(
status_code=404,
json={
"detail": {
"message": "Variable with key 'non_existent_var' not found",
"reason": "not_found",
}
},
)
return httpx.Response(status_code=400, json={"detail": "Bad Request"})

client = make_client(transport=httpx.MockTransport(handle_request))

with pytest.raises(ServerResponseError) as err:
client.variables.get(key="non_existent_var")

assert err.value.response.status_code == 404
assert err.value.detail == {
"detail": {
"message": "Variable with key 'non_existent_var' not found",
"reason": "not_found",
}
}
Loading

0 comments on commit 42ddd22

Please sign in to comment.