Skip to content

Commit

Permalink
feat: ui build in one single http request (#3020)
Browse files Browse the repository at this point in the history
* feat: ui build in one single http request

* fix use session_id

* fix frozen

* [autofix.ci] apply automated fixes

* prettier

* add tests

* add tests

* fix mypy

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
nicoloboschi and autofix-ci[bot] authored Aug 2, 2024
1 parent 51e0829 commit f311a6d
Show file tree
Hide file tree
Showing 12 changed files with 707 additions and 46 deletions.
22 changes: 16 additions & 6 deletions src/backend/base/langflow/api/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import uuid
import warnings
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Dict

from fastapi import HTTPException
from sqlmodel import Session
Expand Down Expand Up @@ -122,12 +122,9 @@ def format_elapsed_time(elapsed_time: float) -> str:
return f"{minutes} {minutes_unit}, {seconds} {seconds_unit}"


async def build_graph_from_db(flow_id: str, session: Session, chat_service: "ChatService"):
async def build_graph_from_data(flow_id: str, payload: Dict, **kwargs):
"""Build and cache the graph."""
flow: Optional[Flow] = session.get(Flow, flow_id)
if not flow or not flow.data:
raise ValueError("Invalid flow ID")
graph = Graph.from_payload(flow.data, flow_id, flow_name=flow.name, user_id=str(flow.user_id))
graph = Graph.from_payload(payload, flow_id, **kwargs)
for vertex_id in graph._has_session_id_vertices:
vertex = graph.get_vertex(vertex_id)
if vertex is None:
Expand All @@ -139,6 +136,19 @@ async def build_graph_from_db(flow_id: str, session: Session, chat_service: "Cha
graph.set_run_id(run_id)
graph.set_run_name()
await graph.initialize_run()
return graph


async def build_graph_from_db_no_cache(flow_id: str, session: Session):
"""Build and cache the graph."""
flow: Optional[Flow] = session.get(Flow, flow_id)
if not flow or not flow.data:
raise ValueError("Invalid flow ID")
return await build_graph_from_data(flow_id, flow.data, flow_name=flow.name, user_id=str(flow.user_id))


async def build_graph_from_db(flow_id: str, session: Session, chat_service: "ChatService"):
graph = await build_graph_from_db_no_cache(flow_id, session)
await chat_service.set_cache(flow_id, graph)
return graph

Expand Down
298 changes: 298 additions & 0 deletions src/backend/base/langflow/api/v1/chat.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import asyncio
import json
import time
import traceback
import typing
import uuid
from typing import TYPE_CHECKING, Annotated, Optional

from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger
from starlette.background import BackgroundTask
from starlette.responses import ContentStream
from starlette.types import Receive

from langflow.api.utils import (
build_and_cache_graph_from_data,
Expand All @@ -14,6 +20,8 @@
format_exception_message,
get_top_level_vertices,
parse_exception,
build_graph_from_db_no_cache,
build_graph_from_data,
)
from langflow.api.v1.schemas import (
FlowDataRequest,
Expand Down Expand Up @@ -140,6 +148,296 @@ async def retrieve_vertices_order(
raise HTTPException(status_code=500, detail=str(exc)) from exc


@router.post("/build/{flow_id}/flow")
async def build_flow(
background_tasks: BackgroundTasks,
flow_id: uuid.UUID,
inputs: Annotated[Optional[InputValueRequest], Body(embed=True)] = None,
data: Annotated[Optional[FlowDataRequest], Body(embed=True)] = None,
files: Optional[list[str]] = None,
stop_component_id: Optional[str] = None,
start_component_id: Optional[str] = None,
chat_service: "ChatService" = Depends(get_chat_service),
current_user=Depends(get_current_active_user),
telemetry_service: "TelemetryService" = Depends(get_telemetry_service),
session=Depends(get_session),
):
async def build_graph_and_get_order() -> tuple[list[str], list[str], "Graph"]:
start_time = time.perf_counter()
components_count = None
try:
flow_id_str = str(flow_id)
if not data:
graph = await build_graph_from_db_no_cache(flow_id=flow_id_str, session=session)
else:
graph = await build_graph_from_data(flow_id_str, data.model_dump())
graph.validate_stream()
if stop_component_id or start_component_id:
try:
first_layer = graph.sort_vertices(stop_component_id, start_component_id)
except Exception as exc:
logger.error(exc)
first_layer = graph.sort_vertices()
else:
first_layer = graph.sort_vertices()

for vertex_id in first_layer:
graph.run_manager.add_to_vertices_being_run(vertex_id)

# Now vertices is a list of lists
# We need to get the id of each vertex
# and return the same structure but only with the ids
components_count = len(graph.vertices)
vertices_to_run = list(graph.vertices_to_run.union(get_top_level_vertices(graph, graph.vertices_to_run)))
background_tasks.add_task(
telemetry_service.log_package_playground,
PlaygroundPayload(
playgroundSeconds=int(time.perf_counter() - start_time),
playgroundComponentCount=components_count,
playgroundSuccess=True,
),
)
return first_layer, vertices_to_run, graph
except Exception as exc:
background_tasks.add_task(
telemetry_service.log_package_playground,
PlaygroundPayload(
playgroundSeconds=int(time.perf_counter() - start_time),
playgroundComponentCount=components_count,
playgroundSuccess=False,
playgroundErrorMessage=str(exc),
),
)
if "stream or streaming set to True" in str(exc):
raise HTTPException(status_code=400, detail=str(exc))
logger.error(f"Error checking build status: {exc}")
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc

async def _build_vertex(vertex_id: str, graph: "Graph") -> VertexBuildResponse:
flow_id_str = str(flow_id)

next_runnable_vertices = []
top_level_vertices = []
start_time = time.perf_counter()
error_message = None
try:
vertex = graph.get_vertex(vertex_id)
try:
lock = chat_service._async_cache_locks[flow_id_str]
(
result_dict,
params,
valid,
artifacts,
vertex,
) = await graph.build_vertex(
chat_service=None,
vertex_id=vertex_id,
user_id=current_user.id,
inputs_dict=inputs.model_dump() if inputs else {},
files=files,
)
next_runnable_vertices = await graph.get_next_runnable_vertices(lock, vertex=vertex, cache=False)
top_level_vertices = graph.get_top_level_vertices(next_runnable_vertices)

result_data_response = ResultDataResponse.model_validate(result_dict, from_attributes=True)
except Exception as exc:
if isinstance(exc, ComponentBuildException):
params = exc.message
tb = exc.formatted_traceback
else:
tb = traceback.format_exc()
logger.exception(f"Error building Component: {exc}")
params = format_exception_message(exc)
message = {"errorMessage": params, "stackTrace": tb}
valid = False
error_message = params
output_label = vertex.outputs[0]["name"] if vertex.outputs else "output"
outputs = {output_label: OutputValue(message=message, type="error")}
result_data_response = ResultDataResponse(results={}, outputs=outputs)
artifacts = {}
background_tasks.add_task(graph.end_all_traces, error=exc)

result_data_response.message = artifacts

# Log the vertex build
if not vertex.will_stream:
background_tasks.add_task(
log_vertex_build,
flow_id=flow_id_str,
vertex_id=vertex_id.split("-")[0],
valid=valid,
params=params,
data=result_data_response,
artifacts=artifacts,
)

timedelta = time.perf_counter() - start_time
duration = format_elapsed_time(timedelta)
result_data_response.duration = duration
result_data_response.timedelta = timedelta
vertex.add_build_time(timedelta)
inactivated_vertices = list(graph.inactivated_vertices)
graph.reset_inactivated_vertices()
graph.reset_activated_vertices()
# graph.stop_vertex tells us if the user asked
# to stop the build of the graph at a certain vertex
# if it is in next_vertices_ids, we need to remove other
# vertices from next_vertices_ids
if graph.stop_vertex and graph.stop_vertex in next_runnable_vertices:
next_runnable_vertices = [graph.stop_vertex]

if not graph.run_manager.vertices_being_run and not next_runnable_vertices:
background_tasks.add_task(graph.end_all_traces)

build_response = VertexBuildResponse(
inactivated_vertices=list(set(inactivated_vertices)),
next_vertices_ids=list(set(next_runnable_vertices)),
top_level_vertices=list(set(top_level_vertices)),
valid=valid,
params=params,
id=vertex.id,
data=result_data_response,
)
background_tasks.add_task(
telemetry_service.log_package_component,
ComponentPayload(
componentName=vertex_id.split("-")[0],
componentSeconds=int(time.perf_counter() - start_time),
componentSuccess=valid,
componentErrorMessage=error_message,
),
)
return build_response
except Exception as exc:
background_tasks.add_task(
telemetry_service.log_package_component,
ComponentPayload(
componentName=vertex_id.split("-")[0],
componentSeconds=int(time.perf_counter() - start_time),
componentSuccess=False,
componentErrorMessage=str(exc),
),
)
logger.error(f"Error building Component: \n\n{exc}")
logger.exception(exc)
message = parse_exception(exc)
raise HTTPException(status_code=500, detail=message) from exc

def send_event(event_type: str, value: dict, queue: asyncio.Queue) -> None:
json_data = {"event": event_type, "data": value}
event_id = uuid.uuid4()
logger.debug(f"sending event {event_id}: {event_type}")
str_data = json.dumps(json_data) + "\n\n"
queue.put_nowait((event_id, str_data.encode("utf-8"), time.time()))

async def build_vertices(
vertex_id: str, graph: "Graph", queue: asyncio.Queue, client_consumed_queue: asyncio.Queue
) -> None:
build_task = asyncio.create_task(await asyncio.to_thread(_build_vertex, vertex_id, graph))
try:
await build_task
except asyncio.CancelledError:
build_task.cancel()
return

vertex_build_response: VertexBuildResponse = build_task.result()
# send built event or error event
send_event("end_vertex", {"build_data": json.loads(vertex_build_response.model_dump_json())}, queue)
await client_consumed_queue.get()
if vertex_build_response.valid:
if vertex_build_response.next_vertices_ids:
tasks = []
for next_vertex_id in vertex_build_response.next_vertices_ids:
task = asyncio.create_task(build_vertices(next_vertex_id, graph, queue, client_consumed_queue))
tasks.append(task)
try:
await asyncio.gather(*tasks)
except asyncio.CancelledError:
for task in tasks:
task.cancel()
return

async def event_generator(queue: asyncio.Queue, client_consumed_queue: asyncio.Queue) -> None:
if not data:
# using another thread since the DB query is I/O bound
vertices_task = asyncio.create_task(await asyncio.to_thread(build_graph_and_get_order))
try:
await vertices_task
except asyncio.CancelledError:
vertices_task.cancel()
return

ids, vertices_to_run, graph = vertices_task.result()
else:
ids, vertices_to_run, graph = await build_graph_and_get_order()
send_event("vertices_sorted", {"ids": ids, "to_run": vertices_to_run}, queue)
await client_consumed_queue.get()

tasks = []
for vertex_id in ids:
task = asyncio.create_task(build_vertices(vertex_id, graph, queue, client_consumed_queue))
tasks.append(task)
try:
await asyncio.gather(*tasks)
except asyncio.CancelledError:
for task in tasks:
task.cancel()
return
send_event("end", {}, queue)
await queue.put((None, None, time.time))

async def consume_and_yield(queue: asyncio.Queue, client_consumed_queue: asyncio.Queue) -> typing.AsyncGenerator:
while True:
event_id, value, put_time = await queue.get()
if value is None:
break
get_time = time.time()
yield value
get_time_yield = time.time()
client_consumed_queue.put_nowait(event_id)
logger.debug(
f"consumed event {str(event_id)} (time in queue, {get_time - put_time:.4f}, client {get_time_yield - get_time:.4f})"
)

asyncio_queue: asyncio.Queue = asyncio.Queue()
asyncio_queue_client_consumed: asyncio.Queue = asyncio.Queue()
main_task = asyncio.create_task(event_generator(asyncio_queue, asyncio_queue_client_consumed))

def on_disconnect():
logger.debug("Client disconnected, closing tasks")
main_task.cancel()

return DisconnectHandlerStreamingResponse(
consume_and_yield(asyncio_queue, asyncio_queue_client_consumed),
media_type="application/x-ndjson",
on_disconnect=on_disconnect,
)


class DisconnectHandlerStreamingResponse(StreamingResponse):
def __init__(
self,
content: ContentStream,
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
on_disconnect: Optional[typing.Callable] = None,
):
super().__init__(content, status_code, headers, media_type, background)
self.on_disconnect = on_disconnect

async def listen_for_disconnect(self, receive: Receive) -> None:
while True:
message = await receive()
if message["type"] == "http.disconnect":
if self.on_disconnect:
await self.on_disconnect()
break


@router.post("/build/{flow_id}/vertices/{vertex_id}")
async def build_vertex(
flow_id: uuid.UUID,
Expand Down
Loading

0 comments on commit f311a6d

Please sign in to comment.