Skip to content

Commit

Permalink
core[patch]: Fixed trim functions, and added corresponding unit test …
Browse files Browse the repository at this point in the history
…for the solved issue (#28429)

- **Description:** 
- Trim functions were incorrectly deleting nodes with more than 1
outgoing/incoming edge, so an extra condition was added to check for
this directly. A unit test "test_trim_multi_edge" was written to test
this test case specifically.
- **Issue:** 
  - Fixes #28411 
  - Fixes langchain-ai/langgraph#1676
- **Dependencies:** 
  - No changes were made to the dependencies

- [x] Unit tests were added to verify the changes.
- [x] Updated documentation where necessary.
- [x] Ran make format, make lint, and make test to ensure compliance
with project standards.

---------

Co-authored-by: Tasif Hussain <[email protected]>
  • Loading branch information
fazam0616 and Tasif1 authored Dec 9, 2024
1 parent 54fba7e commit 481c4bf
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
12 changes: 10 additions & 2 deletions libs/core/langchain_core/runnables/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,14 +470,22 @@ def trim_first_node(self) -> None:
"""Remove the first node if it exists and has a single outgoing edge,
i.e., if removing it would not leave the graph without a "first" node."""
first_node = self.first_node()
if first_node and _first_node(self, exclude=[first_node.id]):
if (
first_node
and _first_node(self, exclude=[first_node.id])
and len({e for e in self.edges if e.source == first_node.id}) == 1
):
self.remove_node(first_node)

def trim_last_node(self) -> None:
"""Remove the last node if it exists and has a single incoming edge,
i.e., if removing it would not leave the graph without a "last" node."""
last_node = self.last_node()
if last_node and _last_node(self, exclude=[last_node.id]):
if (
last_node
and _last_node(self, exclude=[last_node.id])
and len({e for e in self.edges if e.target == last_node.id}) == 1
):
self.remove_node(last_node)

def draw_ascii(self) -> str:
Expand Down
20 changes: 20 additions & 0 deletions libs/core/tests/unit_tests/runnables/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,26 @@ class Schema(BaseModel):
assert graph.last_node() is end


def test_trim_multi_edge() -> None:
class Scheme(BaseModel):
a: str

graph = Graph()
start = graph.add_node(Scheme, id="__start__")
a = graph.add_node(Scheme, id="a")
last = graph.add_node(Scheme, id="__end__")

graph.add_edge(start, a)
graph.add_edge(a, last)
graph.add_edge(start, last)

graph.trim_first_node() # should not remove __start__ since it has 2 outgoing edges
assert graph.first_node() is start

graph.trim_last_node() # should not remove the __end__ node since it has 2 incoming edges
assert graph.last_node() is last


def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
fake_llm = FakeListLLM(responses=["a"])
prompt = PromptTemplate.from_template("Hello, {name}!")
Expand Down

0 comments on commit 481c4bf

Please sign in to comment.