Deduplicate searches in normal mode & across research iterations
- Deduplicate online and document search queries across research iterations.
  This avoids re-running previously executed online and document searches and
  dedupes the online and document context the model sees when generating a
  response.
- Deduplicate the online search queries generated by the chat model for each
  user query.
- Do not pass online, document, and code context separately when generating a
  response in research mode. These are already collected in the meta research
  passed with the user query.
- Improve formatting of the context passed to generate the research response:
  - Use XML tags to delimit context. Pass per-iteration queries in each
    iteration result.
  - Put the user query before the meta research results in the user message
    passed for generating the response.

These deduplications will improve the speed, cost & quality of research mode; the dedup pattern is sketched below.
debanjum committed Nov 11, 2024
1 parent 306f7a2 commit 137687e
Showing 6 changed files with 86 additions and 42 deletions.
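The core of the change is simple set arithmetic over query strings: before a tool runs in a new research iteration, the queries already executed in earlier iterations are subtracted out, and only the remainder is searched. A minimal sketch of the pattern, with illustrative names rather than the exact khoj signatures:

```python
from typing import List, Set

# Illustrative sketch only: khoj threads sets like `previous_subqueries` and
# `previous_inferred_queries` through search_online() and
# extract_references_and_questions(); this helper just shows the pattern.
def dedupe_queries(new_queries: Set[str], previous_queries: Set[str]) -> List[str]:
    """Return only the queries that were not already run in earlier iterations."""
    return list(new_queries - previous_queries)


previous: Set[str] = {"khoj research mode", "khoj online search"}
generated: Set[str] = {"khoj online search", "khoj deduplicate queries"}

to_run = dedupe_queries(generated, previous)
if not to_run:
    print("No new subqueries to search. Skipping iteration.")
else:
    print(f"Searching for: {to_run}")  # only the genuinely new subquery runs
```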
2 changes: 2 additions & 0 deletions src/khoj/processor/conversation/utils.py
@@ -112,13 +112,15 @@ def __init__(
onlineContext: dict = None,
codeContext: dict = None,
summarizedResult: str = None,
warning: str = None,
):
self.tool = tool
self.query = query
self.context = context
self.onlineContext = onlineContext
self.codeContext = codeContext
self.summarizedResult = summarizedResult
self.warning = warning


def construct_iteration_history(
30 changes: 19 additions & 11 deletions src/khoj/processor/tools/online_search.py
@@ -4,7 +4,7 @@
import os
import urllib.parse
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import aiohttp
from bs4 import BeautifulSoup
@@ -66,6 +66,7 @@ async def search_online(
custom_filters: List[str] = [],
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
query_images: List[str] = None,
previous_subqueries: Set = set(),
agent: Agent = None,
tracer: dict = {},
):
@@ -76,19 +77,24 @@
return

# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(
new_subqueries = await generate_online_subqueries(
query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
)
response_dict = {}
subqueries = list(new_subqueries - previous_subqueries)
response_dict: Dict[str, Dict[str, List[Dict] | Dict]] = {}

if subqueries:
logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
if send_status_func:
subqueries_str = "\n- " + "\n- ".join(list(subqueries))
async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"):
yield {ChatEvent.STATUS: event}
if is_none_or_empty(subqueries):
logger.info("No new subqueries to search online")
yield response_dict
return

logger.info(f"🌐 Searching the Internet for {subqueries}")
if send_status_func:
subqueries_str = "\n- " + "\n- ".join(subqueries)
async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"):
yield {ChatEvent.STATUS: event}

with timer(f"Internet searches for {list(subqueries)} took", logger):
with timer(f"Internet searches for {subqueries} took", logger):
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
search_tasks = [search_func(subquery, location) for subquery in subqueries]
search_results = await asyncio.gather(*search_tasks)
@@ -119,7 +125,9 @@ async def search_online(
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
read_webpage_and_extract_content(
data["queries"], link, data.get("content"), user=user, agent=agent, tracer=tracer
)
for link, data in webpages.items()
]
results = await asyncio.gather(*tasks)
4 changes: 3 additions & 1 deletion src/khoj/routers/api.py
@@ -6,7 +6,7 @@
import threading
import time
import uuid
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Set, Union

import cron_descriptor
import pytz
@@ -349,6 +349,7 @@ async def extract_references_and_questions(
location_data: LocationData = None,
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
previous_inferred_queries: Set = set(),
agent: Agent = None,
tracer: dict = {},
):
@@ -477,6 +478,7 @@ async def extract_references_and_questions(
)

# Collate search results as context for GPT
inferred_queries = list(set(inferred_queries) - previous_inferred_queries)
with timer("Searching knowledge base took", logger):
search_results = []
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
3 changes: 2 additions & 1 deletion src/khoj/routers/api_chat.py
@@ -778,7 +778,8 @@ def collect_telemetry():
yield research_result

# researched_results = await extract_relevant_info(q, researched_results, agent)
logger.info(f"Researched Results: {researched_results}")
if state.verbose > 1:
logger.debug(f"Researched Results: {researched_results}")

used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
file_filters = conversation.file_filters if conversation else []
21 changes: 13 additions & 8 deletions src/khoj/routers/helpers.py
@@ -20,6 +20,7 @@
Iterator,
List,
Optional,
Set,
Tuple,
Union,
)
@@ -494,7 +495,7 @@ async def generate_online_subqueries(
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
) -> List[str]:
) -> Set[str]:
"""
Generate subqueries from the given query
"""
@@ -529,14 +530,14 @@
try:
response = clean_json(response)
response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response or len(response) == 0:
response = {q.strip() for q in response["queries"] if q.strip()}
if not isinstance(response, set) or not response or len(response) == 0:
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
return [q]
return {q}
return response
except Exception as e:
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
return [q]
return {q}


async def schedule_query(
@@ -1128,9 +1129,6 @@ def generate_chat_response(

metadata = {}
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
query_to_run = q
if meta_research:
query_to_run = f"AI Research: {meta_research} {q}"
try:
partial_completion = partial(
save_to_conversation_log,
@@ -1148,6 +1146,13 @@
train_of_thought=train_of_thought,
)

query_to_run = q
if meta_research:
query_to_run = f"<query>{q}</query>\n<collected_research>\n{meta_research}\n</collected_research>"
compiled_references = []
online_results = {}
code_results = {}

conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
vision_available = conversation_config.vision_enabled
if not vision_available and query_images:
68 changes: 47 additions & 21 deletions src/khoj/routers/research.py
@@ -43,38 +43,35 @@ async def apick_next_tool(
location: LocationData = None,
user_name: str = None,
agent: Agent = None,
previous_iterations_history: str = None,
previous_iterations: List[InformationCollectionIteration] = [],
max_iterations: int = 5,
send_status_func: Optional[Callable] = None,
tracer: dict = {},
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer.
"""
"""Given a query, determine which of the available tools the agent should use in order to answer appropriately."""

# Construct tool options for the agent to choose from
tool_options = dict()
tool_options_str = ""

agent_tools = agent.input_tools if agent else []

for tool, description in function_calling_description_for_llm.items():
tool_options[tool.value] = description
if len(agent_tools) == 0 or tool.value in agent_tools:
tool_options_str += f'- "{tool.value}": "{description}"\n'

# Construct chat history with user and iteration history with researcher agent for context
chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)

if query_images:
query = f"[placeholder for user attached images]\n{query}"

today = datetime.today()
location_data = f"{location}" if location else "Unknown"
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)

# Extract Past User Message and Inferred Questions from Conversation Log
today = datetime.today()
location_data = f"{location}" if location else "Unknown"

function_planning_prompt = prompts.plan_function_execution.format(
tools=tool_options_str,
chat_history=chat_history,
@@ -112,8 +109,15 @@ async def apick_next_tool(
selected_tool = response.get("tool", None)
generated_query = response.get("query", None)
scratchpad = response.get("scratchpad", None)
warning = None
logger.info(f"Response for determining relevant tools: {response}")
if send_status_func:

# Detect selection of previously used query, tool combination.
previous_tool_query_combinations = {(i.tool, i.query) for i in previous_iterations}
if (selected_tool, generated_query) in previous_tool_query_combinations:
warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different."
# Only send client status updates if we'll execute this iteration
elif send_status_func:
determined_tool_message = "**Determined Tool**: "
determined_tool_message += f"{selected_tool}({generated_query})." if selected_tool else "respond."
determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else ""
@@ -123,13 +127,14 @@
yield InformationCollectionIteration(
tool=selected_tool,
query=generated_query,
warning=warning,
)

except Exception as e:
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
yield InformationCollectionIteration(
tool=None,
query=None,
warning=f"Invalid response for determining relevant tools: {response}. Skipping iteration. Fix error: {e}",
)


@@ -156,7 +161,6 @@ async def execute_information_collection(
document_results: List[Dict[str, str]] = []
summarize_files: str = ""
this_iteration = InformationCollectionIteration(tool=None, query=query)
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)

async for result in apick_next_tool(
query,
@@ -166,7 +170,7 @@
location,
user_name,
agent,
previous_iterations_history,
previous_iterations,
MAX_ITERATIONS,
send_status_func,
tracer=tracer,
@@ -176,9 +180,16 @@
elif isinstance(result, InformationCollectionIteration):
this_iteration = result

if this_iteration.tool == ConversationCommand.Notes:
# Skip running iteration if warning present in iteration
if this_iteration.warning:
logger.warning(f"Research mode: {this_iteration.warning}.")

elif this_iteration.tool == ConversationCommand.Notes:
this_iteration.context = []
document_results = []
previous_inferred_queries = {
c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context
}
async for result in extract_references_and_questions(
request,
construct_tool_chat_history(previous_iterations, ConversationCommand.Notes),
@@ -190,6 +201,7 @@
location,
send_status_func,
query_images,
previous_inferred_queries=previous_inferred_queries,
agent=agent,
tracer=tracer,
):
@@ -213,6 +225,12 @@
logger.error(f"Error extracting document references: {e}", exc_info=True)

elif this_iteration.tool == ConversationCommand.Online:
previous_subqueries = {
subquery
for iteration in previous_iterations
if iteration.onlineContext
for subquery in iteration.onlineContext.keys()
}
async for result in search_online(
this_iteration.query,
construct_tool_chat_history(previous_iterations, ConversationCommand.Online),
@@ -222,11 +240,16 @@
[],
max_webpages_to_read=0,
query_images=query_images,
previous_subqueries=previous_subqueries,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
elif is_none_or_empty(result):
this_iteration.warning = (
"Detected previously run online search queries. Skipping iteration. Try something different."
)
else:
online_results: Dict[str, Dict] = result # type: ignore
this_iteration.onlineContext = online_results
@@ -311,16 +334,19 @@ async def execute_information_collection(

current_iteration += 1

if document_results or online_results or code_results or summarize_files:
results_data = f"**Results**:\n"
if document_results or online_results or code_results or summarize_files or this_iteration.warning:
results_data = f"\n<iteration>{current_iteration}\n<tool>{this_iteration.tool}</tool>\n<query>{this_iteration.query}</query>\n<results>"
if document_results:
results_data += f"**Document References**:\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
results_data += f"\n<document_references>\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</document_references>"
if online_results:
results_data += f"**Online Results**:\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
results_data += f"\n<online_results>\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</online_results>"
if code_results:
results_data += f"**Code Results**:\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
results_data += f"\n<code_results>\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</code_results>"
if summarize_files:
results_data += f"**Summarized Files**:\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
results_data += f"\n<summarized_files>\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</summarized_files>"
if this_iteration.warning:
results_data += f"\n<warning>\n{this_iteration.warning}\n</warning>"
results_data += "\n</results>\n</iteration>"

# intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
this_iteration.summarizedResult = results_data
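For reference, a rough illustration of the XML-delimited per-iteration results block produced by the new formatting in research.py. The iteration number, tool, query, and result values below are invented; the real code interpolates the current iteration's fields and yaml.dump output as shown in the diff above.

```python
import yaml  # PyYAML, used by khoj to serialize results

# Hypothetical iteration values, purely for illustration.
iteration_number = 1
tool, query = "online", "khoj research mode deduplication"
online_results = {query: {"webpages": [{"link": "https://example.org/post", "snippet": "..."}]}}

results_data = f"\n<iteration>{iteration_number}\n<tool>{tool}</tool>\n<query>{query}</query>\n<results>"
results_data += (
    "\n<online_results>\n"
    f"{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
    "</online_results>"
)
results_data += "\n</results>\n</iteration>"
print(results_data)
```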
