Skip to content

Commit

Permalink
fix multi retrieval with resource score issue (langgenius#1578)
Browse files Browse the repository at this point in the history
Co-authored-by: jyong <[email protected]>
  • Loading branch information
JohnJyong and JohnJyong authored Nov 21, 2023
1 parent 2853ba7 commit 4991606
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
7 changes: 6 additions & 1 deletion api/core/tool/dataset_multi_retriever_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def _run(self, query: str) -> str:

hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
hit_callback.on_tool_end(all_documents)
document_score_list = {}
for item in all_documents:
document_score_list[item.metadata['doc_id']] = item.metadata['score']

document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in all_documents]
Expand Down Expand Up @@ -120,8 +123,10 @@ def _run(self, query: str) -> str:
'document_name': document.name,
'data_source_type': document.data_source_type,
'segment_id': segment.id,
'retriever_from': self.retriever_from
'retriever_from': self.retriever_from,
'score': document_score_list.get(segment.index_node_id, None)
}

if self.retriever_from == 'dev':
source['hit_count'] = segment.hit_count
source['word_count'] = segment.word_count
Expand Down
6 changes: 3 additions & 3 deletions api/core/tool/dataset_retriever_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,10 @@ def _run(self, query: str) -> str:
'document_name': document.name,
'data_source_type': document.data_source_type,
'segment_id': segment.id,
'retriever_from': self.retriever_from
'retriever_from': self.retriever_from,
'score': document_score_list.get(segment.index_node_id, None)

}
if dataset.indexing_technique != "economy":
source['score'] = document_score_list.get(segment.index_node_id)
if self.retriever_from == 'dev':
source['hit_count'] = segment.hit_count
source['word_count'] = segment.word_count
Expand Down

0 comments on commit 4991606

Please sign in to comment.