Skip to content

Commit

Permalink
AIP-72: Add support for fetching variables and connections in Supervisor
Browse files Browse the repository at this point in the history
- Updated `VariableOperations` and `ConnectionOperations` in `Client`:
  - Added `get` methods for fetching variable and connection details.
- Refactored communication protocol (`comms.py`):
  - Unified result models (`ConnectionResult`, `VariableResult`, `XComResult`) extending auto-generated models.
  - Renamed `ReadXCom` to `GetXCom` for consistency with other request models (`GetConnection`, `GetVariable`).
- Updated `WatchedSubprocess` in `supervisor.py`:
  - Integrated `handle_requests` to process `GetVariable` and `GetConnection` messages.

and some minor refactors
  • Loading branch information
kaxil committed Nov 21, 2024
1 parent d43052e commit 0175650
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 34 deletions.
27 changes: 23 additions & 4 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
TIHeartbeatInfo,
TITerminalStatePayload,
ValidationError as RemoteValidationError,
VariableResponse,
)
from airflow.utils.net import get_hostname
from airflow.utils.platform import getuser
Expand Down Expand Up @@ -124,17 +125,29 @@ def heartbeat(self, id: uuid.UUID, pid: int):


class ConnectionOperations:
__slots__ = ("client", "decoder")
__slots__ = ("client",)

def __init__(self, client: Client):
self.client = client

def get(self, id: str) -> ConnectionResponse:
def get(self, conn_id: str) -> ConnectionResponse:
"""Get a connection from the API server."""
resp = self.client.get(f"connection/{id}")
resp = self.client.get(f"connections/{conn_id}")
return ConnectionResponse.model_validate_json(resp.read())


class VariableOperations:
__slots__ = ("client",)

def __init__(self, client: Client):
self.client = client

def get(self, key: str) -> VariableResponse:
"""Get a variable from the API server."""
resp = self.client.get(f"variables/{key}")
return VariableResponse.model_validate_json(resp.read())


class BearerAuth(httpx.Auth):
def __init__(self, token: str):
self.token: str = token
Expand Down Expand Up @@ -186,9 +199,15 @@ def task_instances(self) -> TaskInstanceOperations:
@lru_cache() # type: ignore[misc]
@property
def connections(self) -> ConnectionOperations:
"""Operations related to TaskInstances."""
"""Operations related to Connections."""
return ConnectionOperations(self)

@lru_cache() # type: ignore[misc]
@property
def variables(self) -> VariableOperations:
"""Operations related to Variables."""
return VariableOperations(self)


# This is only used for parsing. ServerResponseError is raised instead
class _ErrorBody(BaseModel):
Expand Down
35 changes: 20 additions & 15 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,17 @@

from __future__ import annotations

from typing import Annotated, Any, Literal, Union
from typing import Annotated, Literal, Union

from pydantic import BaseModel, ConfigDict, Field

from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState # noqa: TCH001
from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
TaskInstance,
TerminalTIState,
VariableResponse,
XComResponse,
)


class StartupDetails(BaseModel):
Expand All @@ -64,23 +70,22 @@ class StartupDetails(BaseModel):
type: Literal["StartupDetails"] = "StartupDetails"


class XComResponse(BaseModel):
class XComResult(XComResponse):
"""Response to ReadXCom request."""

key: str
value: Any
type: Literal["XComResponse"] = "XComResult"

type: Literal["XComResponse"] = "XComResponse"

class ConnectionResult(ConnectionResponse):
type: Literal["ConnectionResponse"] = "ConnectionResult"

class ConnectionResponse(BaseModel):
conn: Any

type: Literal["ConnectionResponse"] = "ConnectionResponse"
class VariableResult(VariableResponse):
type: Literal["VariableResponse"] = "VariableResult"


ToTask = Annotated[
Union[StartupDetails, XComResponse, ConnectionResponse],
Union[StartupDetails, XComResult, ConnectionResult, VariableResult],
Field(discriminator="type"),
]

Expand All @@ -98,22 +103,22 @@ class TaskState(BaseModel):
type: Literal["TaskState"] = "TaskState"


class ReadXCom(BaseModel):
class GetXCom(BaseModel):
key: str
type: Literal["ReadXCom"] = "ReadXCom"
type: Literal["ReadXCom"] = "GetXCom"


class GetConnection(BaseModel):
id: str
conn_id: str
type: Literal["GetConnection"] = "GetConnection"


class GetVariable(BaseModel):
id: str
key: str
type: Literal["GetVariable"] = "GetVariable"


ToSupervisor = Annotated[
Union[TaskState, ReadXCom, GetConnection, GetVariable],
Union[TaskState, GetXCom, GetConnection, GetVariable],
Field(discriminator="type"),
]
16 changes: 9 additions & 7 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@
from airflow.sdk.api.client import Client
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import (
ConnectionResponse,
GetConnection,
GetVariable,
StartupDetails,
ToSupervisor,
)
Expand Down Expand Up @@ -480,7 +480,6 @@ def __repr__(self) -> str:
return rep + " >"

def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, None]:
encoder = ConnectionResponse.model_dump_json
# Use a buffer to avoid small allocations
buffer = bytearray(64)

Expand All @@ -495,17 +494,20 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
log.exception("Unable to decode message", line=line)
continue

# if isinstnace(msg, TaskState):
# if isinstance(msg, TaskState):
# self._terminal_state = msg.state
# elif isinstance(msg, ReadXCom):
# resp = XComResponse(key="secret", value=True)
# encoder.encode_into(resp, buffer)
# self.stdin.write(buffer + b"\n")
if isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.id)
resp = ConnectionResponse(conn=conn)
encoded_resp = encoder(resp)
buffer.extend(encoded_resp.encode())
conn = self.client.connections.get(msg.conn_id)
resp = conn.model_dump_json(exclude_unset=True)
buffer.extend(resp.encode())
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
resp = var.model_dump_json(exclude_unset=True)
buffer.extend(resp.encode())
else:
log.error("Unhandled request", msg=msg)
continue
Expand Down
11 changes: 10 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,16 @@ def send_request(self, log: Logger, msg: ToSupervisor):
self.request_socket.write(encoded_msg)


# This global variable will be used by Connection/Variable classes etc to send requests to
# This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution,
# to send requests back to the supervisor process.
#
# Why it needs to be a global:
# - Many parts of Airflow's codebase (e.g., connections, variables, and XComs) may rely on making dynamic requests
# to the parent process during task execution.
# - These calls occur in various locations and cannot easily pass the `CommsDecoder` instance through the
# deeply nested execution stack.
# - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily
# accessible wherever needed during task execution without modifying every layer of the call stack.
SUPERVISOR_COMMS: CommsDecoder

# State machine!
Expand Down
55 changes: 54 additions & 1 deletion task_sdk/tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import httpx
import pytest

from airflow.sdk.api.client import Client, RemoteValidationError, ServerResponseError
from airflow.sdk.api.client import Client, RemoteValidationError, ServerResponseError, VariableOperations
from airflow.sdk.api.datamodels._generated import VariableResponse


class TestClient:
Expand Down Expand Up @@ -74,3 +75,55 @@ def handle_request(request: httpx.Request) -> httpx.Response:
client.get("http://error")
assert err.value.args == ("Not found",)
assert err.value.detail is None


def get_client(transport: httpx.MockTransport) -> Client:
"""Get a client with a custom transport"""
return Client(base_url="test://server", token="", transport=transport)


class TestVariableOperations:
def test_variable_get_success(self):
# Simulate a successful response from the server with a variable
def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == "/variables/test_key":
return httpx.Response(
status_code=200,
json={"key": "test_key", "value": "test_value"},
)
return httpx.Response(status_code=400, json={"detail": "Bad Request"})

client = get_client(transport=httpx.MockTransport(handle_request))
result = VariableOperations(client=client).get(key="test_key")

assert isinstance(result, VariableResponse)
assert result.key == "test_key"
assert result.value == "test_value"

def test_variable_not_found(self):
# Simulate a 404 response from the server
def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == "/variables/non_existent_var":
return httpx.Response(
status_code=404,
json={
"detail": {
"message": "Variable with key 'non_existent_var' not found",
"reason": "not_found",
}
},
)
return httpx.Response(status_code=400, json={"detail": "Bad Request"})

client = get_client(transport=httpx.MockTransport(handle_request))

with pytest.raises(ServerResponseError) as err:
VariableOperations(client=client).get(key="non_existent_var")

assert err.value.response.status_code == 404
assert err.value.detail == {
"detail": {
"message": "Variable with key 'non_existent_var' not found",
"reason": "not_found",
}
}
Loading

0 comments on commit 0175650

Please sign in to comment.