Skip to content

Commit d127fec

Browse files
KKenny0taprosoft
andauthored
feat: support for visualizing citation results (via embeddings) (#461)
* feat:support for visualizing citation results (via embeddings) Signed-off-by: Kennywu <[email protected]> * fix: remove ktem dependency in visualize_cited * fix: limit onnx version for fastembed * fix: test case of indexing * fix: minor update * fix: chroma req * fix: chroma req --------- Signed-off-by: Kennywu <[email protected]> Co-authored-by: Tadashi <[email protected]>
1 parent bd2490b commit d127fec

File tree

4 files changed

+196
-5
lines changed

4 files changed

+196
-5
lines changed

libs/kotaemon/kotaemon/indices/vectorindex.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,11 @@ def to_qa_pipeline(self, *args, **kwargs):
5353
def write_chunk_to_file(self, docs: list[Document]):
5454
# save the chunks content into markdown format
5555
if self.cache_dir:
56-
file_name = Path(docs[0].metadata["file_name"])
56+
file_name = docs[0].metadata.get("file_name")
57+
if not file_name:
58+
return
59+
60+
file_name = Path(file_name)
5761
for i in range(len(docs)):
5862
markdown_content = ""
5963
if "page_label" in docs[i].metadata:

libs/kotaemon/pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ dependencies = [
3838
"langchain-cohere>=0.2.4,<0.3.0",
3939
"llama-hub>=0.0.79,<0.1.0",
4040
"llama-index>=0.10.40,<0.11.0",
41+
"chromadb<=0.5.16",
4142
"llama-index-vector-stores-chroma>=0.1.9",
4243
"llama-index-vector-stores-lancedb",
4344
"openai>=1.23.6,<2",
@@ -52,7 +53,8 @@ dependencies = [
5253
"python-dotenv>=1.0.1,<1.1",
5354
"tenacity>=8.2.3,<8.3",
5455
"theflow>=0.8.6,<0.9.0",
55-
"trogon>=0.5.0,<0.6"
56+
"trogon>=0.5.0,<0.6",
57+
"umap-learn==0.5.5",
5658
]
5759
readme = "README.md"
5860
authors = [
@@ -71,6 +73,7 @@ adv = [
7173
"duckduckgo-search>=6.1.0,<6.2",
7274
"elasticsearch>=8.13.0,<8.14",
7375
"fastembed",
76+
"onnxruntime<v1.20",
7477
"googlesearch-python>=1.2.4,<1.3",
7578
"llama-cpp-python<0.2.8",
7679
"llama-index>=0.10.40,<0.11.0",

libs/ktem/ktem/reasoning/simple.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import numpy as np
1010
import tiktoken
11+
from ktem.embeddings.manager import embedding_models_manager as embeddings
1112
from ktem.llms.manager import llms
1213
from ktem.reasoning.prompt_optimization import (
1314
CreateMindmapPipeline,
@@ -16,6 +17,8 @@
1617
)
1718
from ktem.utils.plantuml import PlantUML
1819
from ktem.utils.render import Render
20+
from ktem.utils.visualize_cited import CreateCitationVizPipeline
21+
from plotly.io import to_json
1922
from theflow.settings import settings as flowsettings
2023

2124
from kotaemon.base import (
@@ -240,6 +243,7 @@ class AnswerWithContextPipeline(BaseComponent):
240243

241244
enable_citation: bool = False
242245
enable_mindmap: bool = False
246+
enable_citation_viz: bool = False
243247

244248
system_prompt: str = ""
245249
lang: str = "English" # support English and Japanese
@@ -409,7 +413,12 @@ def mindmap_call():
409413

410414
answer = Document(
411415
text=output,
412-
metadata={"mindmap": mindmap, "citation": citation, "qa_score": qa_score},
416+
metadata={
417+
"citation_viz": self.enable_citation_viz,
418+
"mindmap": mindmap,
419+
"citation": citation,
420+
"qa_score": qa_score,
421+
},
413422
)
414423

415424
return answer
@@ -474,6 +483,11 @@ class Config:
474483
evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
475484
answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
476485
rewrite_pipeline: RewriteQuestionPipeline | None = None
486+
create_citation_viz_pipeline: CreateCitationVizPipeline = Node(
487+
default_callback=lambda _: CreateCitationVizPipeline(
488+
embedding=embeddings.get_default()
489+
)
490+
)
477491
add_query_context: AddQueryContextPipeline = AddQueryContextPipeline.withx()
478492

479493
def retrieve(
@@ -641,10 +655,28 @@ def prepare_mindmap(self, answer) -> Document | None:
641655

642656
return mindmap_content
643657

644-
def show_citations_and_addons(self, answer, docs):
658+
def prepare_citation_viz(self, answer, question, docs) -> Document | None:
659+
doc_texts = [doc.text for doc in docs]
660+
citation_plot = None
661+
plot_content = None
662+
663+
if answer.metadata["citation_viz"] and len(docs) > 1:
664+
try:
665+
citation_plot = self.create_citation_viz_pipeline(doc_texts, question)
666+
except Exception as e:
667+
print("Failed to create citation plot:", e)
668+
669+
if citation_plot:
670+
plot = to_json(citation_plot)
671+
plot_content = Document(channel="plot", content=plot)
672+
673+
return plot_content
674+
675+
def show_citations_and_addons(self, answer, docs, question):
645676
# show the evidence
646677
with_citation, without_citation = self.prepare_citations(answer, docs)
647678
mindmap_output = self.prepare_mindmap(answer)
679+
citation_plot_output = self.prepare_citation_viz(answer, question, docs)
648680

649681
if not with_citation and not without_citation:
650682
yield Document(channel="info", content="<h5><b>No evidence found.</b></h5>")
@@ -661,6 +693,10 @@ def show_citations_and_addons(self, answer, docs):
661693
if mindmap_output:
662694
yield mindmap_output
663695

696+
# yield citation plot output
697+
if citation_plot_output:
698+
yield citation_plot_output
699+
664700
# yield warning message
665701
if has_llm_score and max_llm_rerank_score < CONTEXT_RELEVANT_WARNING_SCORE:
666702
yield Document(
@@ -733,7 +769,7 @@ def generate_relevant_scores():
733769
if scoring_thread:
734770
scoring_thread.join()
735771

736-
yield from self.show_citations_and_addons(answer, docs)
772+
yield from self.show_citations_and_addons(answer, docs, message)
737773

738774
return answer
739775

@@ -767,6 +803,7 @@ def get_pipeline(cls, settings, states, retrievers):
767803
answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
768804
answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
769805
answer_pipeline.enable_mindmap = settings[f"{prefix}.create_mindmap"]
806+
answer_pipeline.enable_citation_viz = settings[f"{prefix}.create_citation_viz"]
770807
answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
771808
answer_pipeline.qa_template = settings[f"{prefix}.qa_prompt"]
772809
answer_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
@@ -820,6 +857,11 @@ def get_user_settings(cls) -> dict:
820857
"value": False,
821858
"component": "checkbox",
822859
},
860+
"create_citation_viz": {
861+
"name": "Create Embeddings Visualization",
862+
"value": False,
863+
"component": "checkbox",
864+
},
823865
"system_prompt": {
824866
"name": "System Prompt",
825867
"value": "This is a question answering system",
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
"""
2+
This module aims to project high-dimensional embeddings
3+
into a lower-dimensional space for visualization.
4+
5+
Refs:
6+
1. [RAGxplorer](https://github.com/gabrielchua/RAGxplorer)
7+
2. [RAGVizExpander](https://github.com/KKenny0/RAGVizExpander)
8+
"""
9+
from typing import List, Tuple
10+
11+
import numpy as np
12+
import pandas as pd
13+
import plotly.graph_objs as go
14+
import umap
15+
16+
from kotaemon.base import BaseComponent
17+
from kotaemon.embeddings import BaseEmbeddings
18+
19+
VISUALIZATION_SETTINGS = {
20+
"Original Query": {"color": "red", "opacity": 1, "symbol": "cross", "size": 15},
21+
"Retrieved": {"color": "green", "opacity": 1, "symbol": "circle", "size": 10},
22+
"Chunks": {"color": "blue", "opacity": 0.4, "symbol": "circle", "size": 10},
23+
"Sub-Questions": {"color": "purple", "opacity": 1, "symbol": "star", "size": 15},
24+
}
25+
26+
27+
class CreateCitationVizPipeline(BaseComponent):
28+
"""Creating PlotData for visualizing query results"""
29+
30+
embedding: BaseEmbeddings
31+
projector: umap.UMAP = None
32+
33+
def _set_up_umap(self, embeddings: np.ndarray):
34+
umap_transform = umap.UMAP().fit(embeddings)
35+
return umap_transform
36+
37+
def _project_embeddings(self, embeddings, umap_transform) -> np.ndarray:
38+
umap_embeddings = np.empty((len(embeddings), 2))
39+
for i, embedding in enumerate(embeddings):
40+
umap_embeddings[i] = umap_transform.transform([embedding])
41+
return umap_embeddings
42+
43+
def _get_projections(self, embeddings, umap_transform):
44+
projections = self._project_embeddings(embeddings, umap_transform)
45+
x = projections[:, 0]
46+
y = projections[:, 1]
47+
return x, y
48+
49+
def _prepare_projection_df(
50+
self,
51+
document_projections: Tuple[np.ndarray, np.ndarray],
52+
document_text: List[str],
53+
plot_size: int = 3,
54+
) -> pd.DataFrame:
55+
"""Prepares a DataFrame for visualization from projections and texts.
56+
57+
Args:
58+
document_projections (Tuple[np.ndarray, np.ndarray]):
59+
Tuple of X and Y coordinates of document projections.
60+
document_text (List[str]): List of document texts.
61+
"""
62+
df = pd.DataFrame({"x": document_projections[0], "y": document_projections[1]})
63+
df["document"] = document_text
64+
df["document_cleaned"] = df.document.str.wrap(50).apply(
65+
lambda x: x.replace("\n", "<br>")[:512] + "..."
66+
)
67+
df["size"] = plot_size
68+
df["category"] = "Retrieved"
69+
return df
70+
71+
def _plot_embeddings(self, df: pd.DataFrame) -> go.Figure:
72+
"""
73+
Creates a Plotly figure to visualize the embeddings.
74+
75+
Args:
76+
df (pd.DataFrame): DataFrame containing the data to visualize.
77+
78+
Returns:
79+
go.Figure: A Plotly figure object for visualization.
80+
"""
81+
fig = go.Figure()
82+
83+
for category in df["category"].unique():
84+
category_df = df[df["category"] == category]
85+
settings = VISUALIZATION_SETTINGS.get(
86+
category,
87+
{"color": "grey", "opacity": 1, "symbol": "circle", "size": 10},
88+
)
89+
fig.add_trace(
90+
go.Scatter(
91+
x=category_df["x"],
92+
y=category_df["y"],
93+
mode="markers",
94+
name=category,
95+
marker=dict(
96+
color=settings["color"],
97+
opacity=settings["opacity"],
98+
symbol=settings["symbol"],
99+
size=settings["size"],
100+
line_width=0,
101+
),
102+
hoverinfo="text",
103+
text=category_df["document_cleaned"],
104+
)
105+
)
106+
107+
fig.update_layout(
108+
height=500,
109+
legend=dict(y=100, x=0.5, xanchor="center", yanchor="top", orientation="h"),
110+
)
111+
return fig
112+
113+
def run(self, context: List[str], question: str):
114+
embed_contexts = self.embedding(context)
115+
context_embeddings = np.array([d.embedding for d in embed_contexts])
116+
117+
self.projector = self._set_up_umap(embeddings=context_embeddings)
118+
119+
embed_query = self.embedding(question)
120+
query_projection = self._get_projections(
121+
embeddings=[embed_query[0].embedding], umap_transform=self.projector
122+
)
123+
viz_query_df = pd.DataFrame(
124+
{
125+
"x": [query_projection[0][0]],
126+
"y": [query_projection[1][0]],
127+
"document_cleaned": question,
128+
"category": "Original Query",
129+
"size": 5,
130+
}
131+
)
132+
133+
context_projections = self._get_projections(
134+
embeddings=context_embeddings, umap_transform=self.projector
135+
)
136+
viz_base_df = self._prepare_projection_df(
137+
document_projections=context_projections, document_text=context
138+
)
139+
140+
visualization_df = pd.concat([viz_base_df, viz_query_df], axis=0)
141+
fig = self._plot_embeddings(visualization_df)
142+
return fig

0 commit comments

Comments
 (0)