diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 85bb077ee9..35fed20ab1 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -465,7 +465,7 @@ def __init__( self._is_execution_resuming: bool = False # Initialize state with initial values - self._state = self._create_initial_state() + self._state = self._create_initial_state(kwargs) self.tracing = tracing if ( is_tracing_enabled() @@ -474,9 +474,6 @@ def __init__( ): trace_listener = TraceCollectionListener() trace_listener.setup_listeners(crewai_event_bus) - # Apply any additional kwargs - if kwargs: - self._initialize_state(kwargs) crewai_event_bus.emit( self, @@ -502,23 +499,27 @@ def __init__( method = method.__get__(self, self.__class__) self._methods[method_name] = method - def _create_initial_state(self) -> T: + def _create_initial_state(self, kwargs: dict[str, Any] | None = None) -> T: """Create and initialize flow state with UUID and default values. + If kwargs are provided, use them to initialize the state. - Returns: - New state instance with UUID and default values initialized + Args: + kwargs: Dictionary of state values to set/update (id field is ignored if present) Raises: ValueError: If structured state model lacks 'id' field TypeError: If state is neither BaseModel nor dictionary """ + kwargs = kwargs or {} + # prevent overriding the auto-generated ID + kwargs.pop("id", None) + # Handle case where initial_state is None but we have a type parameter if self.initial_state is None and hasattr(self, "_initial_state_t"): state_type = self._initial_state_t if isinstance(state_type, type): if issubclass(state_type, FlowState): - # Create instance without id, then set it - instance = state_type() + instance = state_type(**kwargs) if not hasattr(instance, "id"): instance.id = str(uuid4()) return cast(T, instance) @@ -527,33 +528,38 @@ def _create_initial_state(self) -> T: class StateWithId(state_type, FlowState): # type: ignore pass - instance = StateWithId() + instance = StateWithId(**kwargs) if not hasattr(instance, "id"): instance.id = str(uuid4()) return cast(T, instance) if state_type is dict: - return cast(T, {"id": str(uuid4())}) + return cast(T, {"id": str(uuid4()), **kwargs}) # Handle case where no initial state is provided if self.initial_state is None: - return cast(T, {"id": str(uuid4())}) + return cast(T, {"id": str(uuid4()), **kwargs}) # Handle case where initial_state is a type (class) if isinstance(self.initial_state, type): if issubclass(self.initial_state, FlowState): - return cast(T, self.initial_state()) # Uses model defaults + # Uses model defaults + kwargs + return cast(T, self.initial_state(**kwargs)) if issubclass(self.initial_state, BaseModel): # Validate that the model has an id field model_fields = getattr(self.initial_state, "model_fields", None) if not model_fields or "id" not in model_fields: raise ValueError("Flow state model must have an 'id' field") - return cast(T, self.initial_state()) # Uses model defaults + # Uses model defaults + kwargs + return cast(T, self.initial_state(**kwargs)) if self.initial_state is dict: - return cast(T, {"id": str(uuid4())}) + return cast(T, {"id": str(uuid4()), **kwargs}) # Handle dictionary instance case if isinstance(self.initial_state, dict): - new_state = dict(self.initial_state) # Copy to avoid mutations + # Copy to avoid mutations + new_state = dict(self.initial_state) + # Apply kwargs + new_state.update(kwargs) if "id" not in new_state: new_state["id"] = str(uuid4()) return cast(T, new_state) @@ -577,6 +583,9 @@ class StateWithId(state_type, FlowState): # type: ignore k: v for k, v in model.__dict__.items() if not k.startswith("_") } + # Merge kwargs into state_dict + state_dict.update(kwargs) + # Create new instance of the same class model_class = type(model) return cast(T, model_class(**state_dict)) diff --git a/tests/test_flow.py b/tests/test_flow.py index 504cf8e6e9..325be05a9f 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -4,17 +4,17 @@ from datetime import datetime import pytest -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError -from crewai.flow.flow import Flow, and_, listen, or_, router, start from crewai.events.event_bus import crewai_event_bus from crewai.events.types.flow_events import ( FlowFinishedEvent, - FlowStartedEvent, FlowPlotEvent, + FlowStartedEvent, MethodExecutionFinishedEvent, MethodExecutionStartedEvent, ) +from crewai.flow.flow import Flow, FlowState, and_, listen, or_, router, start def test_simple_sequential_flow(): @@ -275,6 +275,156 @@ def step_2(self): assert flow.counter == 2 +def test_flow_with_required_pydantic_fields(): + """Test a flow with required Pydantic fields passed as kwargs.""" + + class TestState(BaseModel): + field_1: str + field_2: int + + class TestFlow(Flow[TestState]): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.counter = 0 + + @start() + def step_1(self): + self.counter += 1 + + @listen(step_1) + def step_2(self): + self.counter *= 2 + assert self.counter == 2 + + flow = TestFlow(field_1="ABC", field_2=1) + flow.kickoff() + + assert flow.state.field_1 == "ABC" + assert flow.state.field_2 == 1 + assert flow.counter == 2 + + with pytest.raises(ValidationError) as exc_info: + flow = TestFlow() + + assert "field_1" in str(exc_info.value) and "field_2" in str(exc_info.value) + + +def test_flow_with_required_pydantic_fields_and_kickoff_inputs(): + """Test flow with required fields in __init__ and additional inputs in kickoff.""" + + class TestState(BaseModel): + field_1: str = "ABC" + field_2: int + + class TestFlow(Flow[TestState]): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.counter = 0 + + @start() + def step_1(self): + self.counter += 1 + + @listen(step_1) + def step_2(self): + self.counter *= 2 + assert self.counter == 2 + + flow = TestFlow(field_2=1) + assert flow.state.field_1 == "ABC" + + flow.kickoff(inputs={"field_1": "CBA"}) + + assert flow.state.field_1 == "CBA" + assert flow.state.field_2 == 1 + assert flow.counter == 2 + + +def test_flow_with_flow_state_subclass(): + """Test a flow with FlowState subclass and required fields passed as kwargs.""" + + class TestState(FlowState): + field_1: str + + class TestFlow(Flow[TestState]): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.counter = 0 + + @start() + def step_1(self): + self.counter += 1 + + @listen(step_1) + def step_2(self): + self.counter *= 2 + assert self.counter == 2 + + flow = TestFlow(field_1="ABC") + flow.kickoff() + + assert flow.state.field_1 == "ABC" + assert flow.counter == 2 + + with pytest.raises(ValidationError) as exc_info: + flow = TestFlow() + + assert "field_1" in str(exc_info.value) + + +def test_flow_ignore_id(): + """Test a flow where initial id value is ignored when passed as kwarg.""" + + class TestState(FlowState): + field_1: str + + class TestFlow(Flow[TestState]): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.counter = 0 + + @start() + def step_1(self): + self.counter += 1 + + @listen(step_1) + def step_2(self): + self.counter *= 2 + assert self.counter == 2 + + flow = TestFlow(id="test_id", field_1="ABC") + flow.kickoff() + + assert flow.state.id != "test_id" + assert flow.state.field_1 == "ABC" + assert flow.counter == 2 + + +def test_flow_without_initial_state(): + """Test a flow init with state fields passed as kwargs.""" + + class TestFlow(Flow): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.counter = 0 + + @start() + def step_1(self): + self.counter += 1 + + @listen(step_1) + def step_2(self): + self.counter *= 2 + assert self.counter == 2 + + flow = TestFlow(field_1="ABC") + flow.kickoff() + + assert isinstance(flow.state, dict) + assert flow.state.get("field_1") == "ABC" + assert flow.counter == 2 + + def test_flow_uuid_unstructured(): """Test that unstructured (dictionary) flow states automatically get a UUID that persists.""" initial_id = None @@ -679,11 +829,11 @@ def handle_flow_end(source, event): assert isinstance(received_events[3], MethodExecutionStartedEvent) assert received_events[3].method_name == "send_welcome_message" assert received_events[3].params == {} - assert getattr(received_events[3].state, "sent") is False + assert getattr(received_events[3].state, "sent") is False # noqa: B009 assert isinstance(received_events[4], MethodExecutionFinishedEvent) assert received_events[4].method_name == "send_welcome_message" - assert getattr(received_events[4].state, "sent") is True + assert getattr(received_events[4].state, "sent") is True # noqa: B009 assert received_events[4].result == "Welcome, Anakin!" assert isinstance(received_events[5], FlowFinishedEvent)