diff --git a/src/crewai/agents/crew_agent_executor.py b/src/crewai/agents/crew_agent_executor.py
index 813ac8a086..a2227a5cf9 100644
--- a/src/crewai/agents/crew_agent_executor.py
+++ b/src/crewai/agents/crew_agent_executor.py
@@ -112,6 +112,8 @@ def _invoke_loop(self, formatted_answer=None):
         try:
             while not isinstance(formatted_answer, AgentFinish):
                 if not self.request_within_rpm_limit or self.request_within_rpm_limit():
+                    self._check_context_length_before_call()
+
                     answer = self.llm.call(
                         self.messages,
                         callbacks=self.callbacks,
@@ -327,6 +329,20 @@ def _summarize_messages(self) -> None:
             )
         ]
 
+    def _check_context_length_before_call(self) -> None:
+        # Rough estimate: assume ~4 characters per token on average.
+        total_chars = sum(len(msg.get("content", "")) for msg in self.messages)
+        estimated_tokens = total_chars // 4
+
+        context_window_size = self.llm.get_context_window_size()
+
+        if estimated_tokens > context_window_size:
+            self._printer.print(
+                content=f"Estimated token count ({estimated_tokens}) exceeds context window ({context_window_size}). Handling proactively.",
+                color="yellow",
+            )
+            self._handle_context_length()
+
     def _handle_context_length(self) -> None:
         if self.respect_context_window:
             self._printer.print(
diff --git a/tests/agent_test.py b/tests/agent_test.py
index 6879a4519b..642c5b84e1 100644
--- a/tests/agent_test.py
+++ b/tests/agent_test.py
@@ -1625,3 +1625,85 @@ def test_agent_with_knowledge_sources():
 
     # Assert that the agent provides the correct information
     assert "red" in result.raw.lower()
+
+
+def test_proactive_context_length_handling_prevents_empty_response():
+    """Test that proactive context length checking prevents empty LLM responses."""
+    from crewai.agents.crew_agent_executor import CrewAgentExecutor
+
+    agent = Agent(
+        role="test role",
+        goal="test goal",
+        backstory="test backstory",
+        respect_context_window=True,
+    )
+
+    long_input = "This is a very long input that should exceed the context window. " * 1000
+
+    # Embed the long input in the task so the executor's messages exceed the
+    # (mocked) context window. _handle_context_length is patched on the class
+    # because execute_task recreates the agent executor.
+    task = Task(
+        description=f"Process this long input: {long_input}",
+        expected_output="A response",
+        agent=agent,
+    )
+
+    with patch.object(agent.llm, "get_context_window_size", return_value=100):
+        with patch.object(CrewAgentExecutor, "_handle_context_length") as mock_handle:
+            with patch.object(
+                agent.llm,
+                "call",
+                return_value="Final Answer: Proper response after summarization",
+            ):
+                result = agent.execute_task(task)
+
+    mock_handle.assert_called()
+    assert result and result.strip() != ""
+
+
+def test_proactive_context_length_handling_with_no_summarization():
+    """Test proactive context length checking when summarization is disabled."""
+    agent = Agent(
+        role="test role",
+        goal="test goal",
+        backstory="test backstory",
+        respect_context_window=False,
+    )
+    agent.create_agent_executor()
+
+    long_input = "This is a very long input. " * 1000
" * 1000 + + with patch.object(agent.llm, 'get_context_window_size', return_value=100): + agent.agent_executor.messages = [ + {"role": "user", "content": long_input} + ] + + with pytest.raises(SystemExit): + agent.agent_executor._check_context_length_before_call() + + +def test_context_length_estimation(): + """Test the token estimation logic.""" + agent = Agent( + role="test role", + goal="test goal", + backstory="test backstory", + ) + + agent.agent_executor.messages = [ + {"role": "user", "content": "Short message"}, + {"role": "assistant", "content": "Another short message"}, + ] + + with patch.object(agent.llm, 'get_context_window_size', return_value=10): + with patch.object(agent.agent_executor, '_handle_context_length') as mock_handle: + agent.agent_executor._check_context_length_before_call() + mock_handle.assert_not_called() + + with patch.object(agent.llm, 'get_context_window_size', return_value=5): + with patch.object(agent.agent_executor, '_handle_context_length') as mock_handle: + agent.agent_executor._check_context_length_before_call() + mock_handle.assert_called()