Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions src/crewai/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def __init__(
self._is_execution_resuming: bool = False

# Initialize state with initial values
self._state = self._create_initial_state()
self._state = self._create_initial_state(kwargs)
self.tracing = tracing
if (
is_tracing_enabled()
Expand All @@ -474,9 +474,6 @@ def __init__(
):
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
# Apply any additional kwargs
if kwargs:
self._initialize_state(kwargs)

crewai_event_bus.emit(
self,
Expand All @@ -502,23 +499,27 @@ def __init__(
method = method.__get__(self, self.__class__)
self._methods[method_name] = method

def _create_initial_state(self) -> T:
def _create_initial_state(self, kwargs: dict[str, Any] | None = None) -> T:
"""Create and initialize flow state with UUID and default values.
If kwargs are provided, use them to initialize the state.

Returns:
New state instance with UUID and default values initialized
Args:
kwargs: Dictionary of state values to set/update (id field is ignored if present)

Raises:
ValueError: If structured state model lacks 'id' field
TypeError: If state is neither BaseModel nor dictionary
"""
kwargs = kwargs or {}
# prevent overriding the auto-generated ID
kwargs.pop("id", None)

# Handle case where initial_state is None but we have a type parameter
if self.initial_state is None and hasattr(self, "_initial_state_t"):
state_type = self._initial_state_t
if isinstance(state_type, type):
if issubclass(state_type, FlowState):
# Create instance without id, then set it
instance = state_type()
instance = state_type(**kwargs)
if not hasattr(instance, "id"):
instance.id = str(uuid4())
return cast(T, instance)
Expand All @@ -527,33 +528,38 @@ def _create_initial_state(self) -> T:
class StateWithId(state_type, FlowState): # type: ignore
pass

instance = StateWithId()
instance = StateWithId(**kwargs)
if not hasattr(instance, "id"):
instance.id = str(uuid4())
return cast(T, instance)
if state_type is dict:
return cast(T, {"id": str(uuid4())})
return cast(T, {"id": str(uuid4()), **kwargs})

# Handle case where no initial state is provided
if self.initial_state is None:
return cast(T, {"id": str(uuid4())})
return cast(T, {"id": str(uuid4()), **kwargs})

# Handle case where initial_state is a type (class)
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, FlowState):
return cast(T, self.initial_state()) # Uses model defaults
# Uses model defaults + kwargs
return cast(T, self.initial_state(**kwargs))
if issubclass(self.initial_state, BaseModel):
# Validate that the model has an id field
model_fields = getattr(self.initial_state, "model_fields", None)
if not model_fields or "id" not in model_fields:
raise ValueError("Flow state model must have an 'id' field")
return cast(T, self.initial_state()) # Uses model defaults
# Uses model defaults + kwargs
return cast(T, self.initial_state(**kwargs))
if self.initial_state is dict:
return cast(T, {"id": str(uuid4())})
return cast(T, {"id": str(uuid4()), **kwargs})

# Handle dictionary instance case
if isinstance(self.initial_state, dict):
new_state = dict(self.initial_state) # Copy to avoid mutations
# Copy to avoid mutations
new_state = dict(self.initial_state)
# Apply kwargs
new_state.update(kwargs)
if "id" not in new_state:
new_state["id"] = str(uuid4())
return cast(T, new_state)
Expand All @@ -577,6 +583,9 @@ class StateWithId(state_type, FlowState): # type: ignore
k: v for k, v in model.__dict__.items() if not k.startswith("_")
}

# Merge kwargs into state_dict
state_dict.update(kwargs)

# Create new instance of the same class
model_class = type(model)
return cast(T, model_class(**state_dict))
Expand Down
160 changes: 155 additions & 5 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
from datetime import datetime

import pytest
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError

from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.flow_events import (
FlowFinishedEvent,
FlowStartedEvent,
FlowPlotEvent,
FlowStartedEvent,
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.flow.flow import Flow, FlowState, and_, listen, or_, router, start


def test_simple_sequential_flow():
Expand Down Expand Up @@ -275,6 +275,156 @@ def step_2(self):
assert flow.counter == 2


def test_flow_with_required_pydantic_fields():
"""Test a flow with required Pydantic fields passed as kwargs."""

class TestState(BaseModel):
field_1: str
field_2: int

class TestFlow(Flow[TestState]):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.counter = 0

@start()
def step_1(self):
self.counter += 1

@listen(step_1)
def step_2(self):
self.counter *= 2
assert self.counter == 2

flow = TestFlow(field_1="ABC", field_2=1)
flow.kickoff()

assert flow.state.field_1 == "ABC"
assert flow.state.field_2 == 1
assert flow.counter == 2

with pytest.raises(ValidationError) as exc_info:
flow = TestFlow()

assert "field_1" in str(exc_info.value) and "field_2" in str(exc_info.value)


def test_flow_with_required_pydantic_fields_and_kickoff_inputs():
"""Test flow with required fields in __init__ and additional inputs in kickoff."""

class TestState(BaseModel):
field_1: str = "ABC"
field_2: int

class TestFlow(Flow[TestState]):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.counter = 0

@start()
def step_1(self):
self.counter += 1

@listen(step_1)
def step_2(self):
self.counter *= 2
assert self.counter == 2

flow = TestFlow(field_2=1)
assert flow.state.field_1 == "ABC"

flow.kickoff(inputs={"field_1": "CBA"})

assert flow.state.field_1 == "CBA"
assert flow.state.field_2 == 1
assert flow.counter == 2


def test_flow_with_flow_state_subclass():
"""Test a flow with FlowState subclass and required fields passed as kwargs."""

class TestState(FlowState):
field_1: str

class TestFlow(Flow[TestState]):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.counter = 0

@start()
def step_1(self):
self.counter += 1

@listen(step_1)
def step_2(self):
self.counter *= 2
assert self.counter == 2

flow = TestFlow(field_1="ABC")
flow.kickoff()

assert flow.state.field_1 == "ABC"
assert flow.counter == 2

with pytest.raises(ValidationError) as exc_info:
flow = TestFlow()

assert "field_1" in str(exc_info.value)


def test_flow_ignore_id():
"""Test a flow where initial id value is ignored when passed as kwarg."""

class TestState(FlowState):
field_1: str

class TestFlow(Flow[TestState]):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.counter = 0

@start()
def step_1(self):
self.counter += 1

@listen(step_1)
def step_2(self):
self.counter *= 2
assert self.counter == 2

flow = TestFlow(id="test_id", field_1="ABC")
flow.kickoff()

assert flow.state.id != "test_id"
assert flow.state.field_1 == "ABC"
assert flow.counter == 2


def test_flow_without_initial_state():
"""Test a flow init with state fields passed as kwargs."""

class TestFlow(Flow):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.counter = 0

@start()
def step_1(self):
self.counter += 1

@listen(step_1)
def step_2(self):
self.counter *= 2
assert self.counter == 2

flow = TestFlow(field_1="ABC")
flow.kickoff()

assert isinstance(flow.state, dict)
assert flow.state.get("field_1") == "ABC"
assert flow.counter == 2


def test_flow_uuid_unstructured():
"""Test that unstructured (dictionary) flow states automatically get a UUID that persists."""
initial_id = None
Expand Down Expand Up @@ -679,11 +829,11 @@ def handle_flow_end(source, event):
assert isinstance(received_events[3], MethodExecutionStartedEvent)
assert received_events[3].method_name == "send_welcome_message"
assert received_events[3].params == {}
assert getattr(received_events[3].state, "sent") is False
assert getattr(received_events[3].state, "sent") is False # noqa: B009

assert isinstance(received_events[4], MethodExecutionFinishedEvent)
assert received_events[4].method_name == "send_welcome_message"
assert getattr(received_events[4].state, "sent") is True
assert getattr(received_events[4].state, "sent") is True # noqa: B009
assert received_events[4].result == "Welcome, Anakin!"

assert isinstance(received_events[5], FlowFinishedEvent)
Expand Down