Skip to content

Commit 94088cf

Browse files
committed
Added media enrichment to embedding process, with a test
1 parent 0da9a63 commit 94088cf

File tree

4 files changed

+739
-6
lines changed

4 files changed

+739
-6
lines changed

src/paperqa/docs.py

Lines changed: 23 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import json
45
import logging
56
import os
@@ -481,7 +482,16 @@ async def aadd_texts(
481482
if embedding_model and texts[0].embedding is None:
482483
for t, t_embedding in zip(
483484
texts,
484-
await embedding_model.embed_documents(texts=[t.text for t in texts]),
485+
await embedding_model.embed_documents(
486+
texts=await asyncio.gather(
487+
*(
488+
t.get_embeddable_text(
489+
all_settings.parsing.should_parse_and_enrich_media[1]
490+
)
491+
for t in texts
492+
)
493+
)
494+
),
485495
strict=True,
486496
):
487497
t.embedding = t_embedding
@@ -535,14 +545,20 @@ def delete(
535545
self.deleted_dockeys.add(dockey)
536546
self.texts = list(filter(lambda x: x.doc.dockey != dockey, self.texts))
537547

538-
async def _build_texts_index(self, embedding_model: EmbeddingModel) -> None:
548+
async def _build_texts_index(
549+
self, embedding_model: EmbeddingModel, with_enrichment: bool = False
550+
) -> None:
539551
texts = [t for t in self.texts if t not in self.texts_index]
540552
# For any embeddings we are supposed to lazily embed, embed them now
541553
to_embed = [t for t in texts if t.embedding is None]
542554
if to_embed:
543555
for t, t_embedding in zip(
544556
to_embed,
545-
await embedding_model.embed_documents(texts=[t.text for t in to_embed]),
557+
await embedding_model.embed_documents(
558+
texts=await asyncio.gather(
559+
*(t.get_embeddable_text(with_enrichment) for t in to_embed)
560+
)
561+
),
546562
strict=True,
547563
):
548564
t.embedding = t_embedding
@@ -564,7 +580,10 @@ async def retrieve_texts(
564580
# TODO: should probably happen elsewhere
565581
self.texts_index.mmr_lambda = settings.texts_index_mmr_lambda
566582

567-
await self._build_texts_index(embedding_model)
583+
await self._build_texts_index(
584+
embedding_model,
585+
with_enrichment=settings.parsing.should_parse_and_enrich_media[1],
586+
)
568587
_k = k + len(self.deleted_dockeys)
569588
matches: list[Text] = cast(
570589
"list[Text]",

src/paperqa/types.py

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -173,6 +173,33 @@ def __eq__(self, other) -> bool:
173173
def __hash__(self) -> int:
    """Return a hash computed from the ``name`` and ``text`` attributes."""
    return hash((self.name, self.text))
175175

176+
async def get_embeddable_text(self, with_enrichment: bool = False) -> str:
    """Get the text to embed, which may be different from the actual text content.

    This method is async so subclasses can use custom enrichment logic here.

    Args:
        with_enrichment: Opt-in flag to include media enrichment in the return.
            Media enrichment can improve placement in embedding space,
            without affecting the text used for quotation.

    Returns:
        Content to embed. When ``with_enrichment`` is false this is
        ``self.text`` unchanged. Otherwise it is ``self.text`` followed by one
        section per entry of ``self.media`` whose ``info`` has a truthy
        ``"enriched_description"``, with sections separated by blank lines.
    """
    if not with_enrichment:
        return self.text
    # Only media carrying a non-empty enriched description contribute a
    # section; when none do, the join below yields self.text unchanged.
    enriched_media = (
        (
            f"Media {m.index} from page {m.info.get('page_num', 'unknown')!s}'s"
            f" enriched description:\n\n{m.info['enriched_description']!s}"
        )
        for m in self.media
        if m.info.get("enriched_description")
    )
    return "\n\n".join((self.text, *enriched_media))
202+
176203

177204
# Sentinel to autopopulate a field within model_validator
178205
AUTOPOPULATE_VALUE = "" # NOTE: this is falsy by design

0 commit comments

Comments
 (0)