Skip to content

Commit

Permalink
fix: make webhook api call honor webhook component as input (langflow…
Browse files Browse the repository at this point in the history
…-ai#2511)

* refactor(base.py): refactor logic to find start_component_id based on multiple keywords for improved flexibility and readability

* feat(schema.py): add WebhookInput component type to INPUT_COMPONENTS list for handling webhook inputs in the graph schema

* refactor(base.py): refactor logic to determine start_component_id based on webhook or chat component presence in input vertices

* refactor: prioritize webhook component for determining start_component_id

* feat(utils.py): add function find_start_component_id to find component ID based on priority list of input types

* refactor(graph/base.py): refactor logic to find start component id in Graph class for better readability and maintainability

* test(test_webhook.py): override pytest fixture to check for OpenAI API key in environment variables before running tests

* test(test_webhook.py): update webhook json

* feat(schema.py): update WebhookInput component type name

* refactor: log package run telemetry in simplified_run_flow

* test: add test for webhook flow on run endpoint

* refactor(graph/base.py): skip unbuilt vertices when getting vertex outputs in Graph class

* refactor: simplify data_input assignment in LCTextSplitterComponent

* refactor: remove unused build method in CharacterTextSplitterComponent

* refactor: update imports in CharacterTextSplitter.py
  • Loading branch information
ogabrielluiz committed Jul 9, 2024
1 parent b6bf452 commit df306d3
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 265 deletions.
11 changes: 8 additions & 3 deletions src/backend/base/langflow/api/v1/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,13 @@ async def simplified_run_flow(
return result

except ValueError as exc:
end_time = time.perf_counter()
background_tasks.add_task(
telemetry_service.log_package_run,
RunPayload(
runIsWebhook=False, runSeconds=int(end_time - start_time), runSuccess=False, runErrorMessage=str(exc)
runIsWebhook=False,
runSeconds=int(time.perf_counter() - start_time),
runSuccess=False,
runErrorMessage=str(exc),
),
)
if "badly formed hexadecimal UUID string" in str(exc):
Expand All @@ -234,7 +236,10 @@ async def simplified_run_flow(
background_tasks.add_task(
telemetry_service.log_package_run,
RunPayload(
runIsWebhook=False, runSeconds=int(end_time - start_time), runSuccess=False, runErrorMessage=str(exc)
runIsWebhook=False,
runSeconds=int(time.perf_counter() - start_time),
runSuccess=False,
runErrorMessage=str(exc),
),
)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
Expand Down
4 changes: 2 additions & 2 deletions src/backend/base/langflow/base/textsplitters/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod
from typing import Any
from langchain_text_splitters import TextSplitter

from langchain_text_splitters import TextSplitter

from langflow.custom import Component
from langflow.io import Output
Expand Down Expand Up @@ -29,7 +29,7 @@ def split_data(self) -> list[Data]:
documents = []

if not isinstance(data_input, list):
data_input: list[Any] = [data_input]
data_input = [data_input]

for _input in data_input:
if isinstance(_input, Data):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import List, Any
from typing import Any

from langchain_text_splitters import CharacterTextSplitter, TextSplitter

from langflow.base.textsplitters.model import LCTextSplitterComponent
from langflow.inputs import IntInput, DataInput, MessageTextInput
from langflow.schema import Data
from langflow.inputs import DataInput, IntInput, MessageTextInput
from langflow.utils.util import unescape_string


Expand Down Expand Up @@ -53,27 +52,3 @@ def build_text_splitter(self) -> TextSplitter:
chunk_size=self.chunk_size,
separator=separator,
)

def build(
self,
inputs: List[Data],
chunk_overlap: int = 200,
chunk_size: int = 1000,
separator: str = "\n",
) -> List[Data]:
# separator may come escaped from the frontend
separator = unescape_string(separator)
documents = []
for _input in inputs:
if isinstance(_input, Data):
documents.append(_input.to_lc_document())
else:
documents.append(_input)
docs = CharacterTextSplitter(
chunk_overlap=chunk_overlap,
chunk_size=chunk_size,
separator=separator,
).split_documents(documents)
data = self.to_data(docs)
self.status = data
return data
9 changes: 5 additions & 4 deletions src/backend/base/langflow/graph/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from langflow.graph.graph.constants import lazy_load_vertex_dict
from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
from langflow.graph.graph.state_manager import GraphStateManager
from langflow.graph.graph.utils import process_flow
from langflow.graph.graph.utils import find_start_component_id, process_flow
from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
from langflow.graph.vertex.base import Vertex
from langflow.graph.vertex.types import InterfaceVertex, StateVertex
Expand Down Expand Up @@ -335,9 +335,8 @@ async def _run(
logger.exception(exc)

try:
start_component_id = next(
(vertex_id for vertex_id in self._is_input_vertices if "chat" in vertex_id.lower()), None
)
# Prioritize the webhook component if it exists
start_component_id = find_start_component_id(self._is_input_vertices)
await self.process(start_component_id=start_component_id, fallback_to_env_vars=fallback_to_env_vars)
self.increment_run_count()
except Exception as exc:
Expand All @@ -350,6 +349,8 @@ async def _run(
# Get the outputs
vertex_outputs = []
for vertex in self.vertices:
if not vertex._built:
continue
if vertex is None:
raise ValueError(f"Vertex {vertex_id} not found")

Expand Down
21 changes: 20 additions & 1 deletion src/backend/base/langflow/graph/graph/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,24 @@
from collections import deque
import copy
from collections import deque

PRIORITY_LIST_OF_INPUTS = ["webhook", "chat"]


def find_start_component_id(vertices):
"""
Finds the component ID from a list of vertices based on a priority list of input types.
Args:
vertices (list): A list of vertex IDs.
Returns:
str or None: The component ID that matches the highest priority input type, or None if no match is found.
"""
for input_type_str in PRIORITY_LIST_OF_INPUTS:
component_id = next((vertex_id for vertex_id in vertices if input_type_str in vertex_id.lower()), None)
if component_id:
return component_id
return None


def find_last_node(nodes, edges):
Expand Down
2 changes: 2 additions & 0 deletions src/backend/base/langflow/graph/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class InterfaceComponentTypes(str, Enum, metaclass=ContainsEnumMeta):
TextInput = "TextInput"
TextOutput = "TextOutput"
DataOutput = "DataOutput"
WebhookInput = "Webhook"

def __contains__(cls, item):
try:
Expand All @@ -69,6 +70,7 @@ def __contains__(cls, item):
INPUT_COMPONENTS = [
InterfaceComponentTypes.ChatInput,
InterfaceComponentTypes.TextInput,
InterfaceComponentTypes.WebhookInput,
]
OUTPUT_COMPONENTS = [
InterfaceComponentTypes.ChatOutput,
Expand Down
229 changes: 1 addition & 228 deletions tests/data/WebhookTest.json

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions tests/test_webhook.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import tempfile
from pathlib import Path

import pytest


@pytest.fixture(autouse=True)
def check_openai_api_key_in_environment_variables():
pass


def test_webhook_endpoint(client, added_webhook_test):
# The test is as follows:
Expand Down Expand Up @@ -28,6 +35,18 @@ def test_webhook_endpoint(client, added_webhook_test):
assert not file_path.exists()


def test_webhook_flow_on_run_endpoint(client, added_webhook_test, created_api_key):
endpoint_name = added_webhook_test["endpoint_name"]
endpoint = f"api/v1/run/{endpoint_name}?stream=false"
# Just test that "Random Payload" returns 202
# returns 202
payload = {
"output_type": "any",
}
response = client.post(endpoint, headers={"x-api-key": created_api_key.api_key}, json=payload)
assert response.status_code == 200, response.json()


def test_webhook_with_random_payload(client, added_webhook_test):
endpoint_name = added_webhook_test["endpoint_name"]
endpoint = f"api/v1/webhook/{endpoint_name}"
Expand Down

0 comments on commit df306d3

Please sign in to comment.