import os
import uuid
from typing import Any

from pydantic import BaseModel
from unstructured.partition.pdf import partition_pdf
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain.vectorstores import Chroma
from langchain_core.documents import Document

from config import *  # project-level settings
class Element(BaseModel):
    """A categorized PDF element: either a text chunk or a table."""

    type: str
    text: Any
def loader_bill(filename):
    """Parse a PDF bill, summarize its text and table chunks, and return a
    MultiVectorRetriever that indexes the summaries against the raw chunks."""
    parent_directory = os.path.dirname(filename)

    # Define the path for the image output directory
    image_output_dir_path = os.path.join(parent_directory, "image")
    os.makedirs(image_output_dir_path, exist_ok=True)
    raw_pdf_elements = partition_pdf(
        filename=filename,
        # Extract embedded image blocks from the PDF
        extract_images_in_pdf=True,
        # Use a layout model (YOLOX) to get bounding boxes (for tables) and find titles
        # Titles are any sub-section of the document
        infer_table_structure=True,
        # Post-processing to aggregate text once we have the titles
        chunking_strategy="by_title",
        # Chunking params to aggregate text blocks:
        # attempt to start a new chunk after 3800 chars,
        # combine chunks smaller than 2000 chars,
        # hard maximum of 4000 chars per chunk
        max_characters=4000,
        new_after_n_chars=3800,
        combine_text_under_n_chars=2000,
        image_output_dir_path=image_output_dir_path,
    )
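
    # partition_pdf returns a mix of element objects; with by_title chunking
    # the interesting ones here are CompositeElement (aggregated narrative
    # text) and Table. (Assumed typical output of unstructured; exact types
    # can vary by version.)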
    # Count how many elements of each type were extracted
    category_counts = {}
    for element in raw_pdf_elements:
        category = str(type(element))
        category_counts[category] = category_counts.get(category, 0) + 1

    # Unique_categories will have unique elements
    # TableChunk if Table > max chars set above
    unique_categories = set(category_counts.keys())
    # Categorize elements by type
    categorized_elements = []
    for element in raw_pdf_elements:
        if "unstructured.documents.elements.Table" in str(type(element)):
            categorized_elements.append(Element(type="table", text=str(element)))
        elif "unstructured.documents.elements.CompositeElement" in str(type(element)):
            categorized_elements.append(Element(type="text", text=str(element)))

    # Tables
    table_elements = [e for e in categorized_elements if e.type == "table"]
    print(f"{len(table_elements)} table elements")

    # Text
    text_elements = [e for e in categorized_elements if e.type == "text"]
    print(f"{len(text_elements)} text elements")
    prompt_text = """You are an assistant tasked with summarizing tables and text. \
Give a concise summary of the table or text. Table or text chunk: {element} """
    prompt = ChatPromptTemplate.from_template(prompt_text)

    # Summary chain: pass each chunk through unchanged as the "element"
    # prompt variable, then summarize it with the chat model
    model = ChatOpenAI(temperature=0, model="gpt-4")
    summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

    # Apply to text chunks
    texts = [i.text for i in text_elements]
    text_summaries = summarize_chain.batch(texts, {"max_concurrency": 5})

    # Apply to tables
    tables = [i.text for i in table_elements]
    table_summaries = summarize_chain.batch(tables, {"max_concurrency": 5})
    # The vectorstore to use to index the child chunks
    vectorstore = Chroma(collection_name="summaries", embedding_function=OpenAIEmbeddings())

    # The storage layer for the parent documents
    store = InMemoryStore()
    id_key = "doc_id"

    # The retriever (empty to start)
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=store,
        id_key=id_key,
    )
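
    # How the multi-vector pattern works: the short summaries are embedded and
    # searched in the vectorstore, and each hit's doc_id metadata is used to
    # look up and return the full original chunk from the docstore.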
    # Add texts
    doc_ids = [str(uuid.uuid4()) for _ in texts]
    summary_texts = [
        Document(page_content=s, metadata={id_key: doc_id})  # Ensure id_key is in metadata
        for s, doc_id in zip(text_summaries, doc_ids)
    ]
    retriever.vectorstore.add_documents(summary_texts)
    retriever.docstore.mset(list(zip(doc_ids, texts)))
    # Add tables
    table_ids = [str(uuid.uuid4()) for _ in tables]
    summary_tables = [
        Document(page_content=s, metadata={id_key: table_id})  # Ensure id_key is in metadata
        for s, table_id in zip(table_summaries, table_ids)
    ]
    retriever.vectorstore.add_documents(summary_tables)
    retriever.docstore.mset(list(zip(table_ids, tables)))

    return retriever
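

if __name__ == "__main__":
    # Minimal usage sketch. The PDF path and query below are illustrative
    # placeholders, not part of the original module; running this also
    # requires valid OpenAI credentials (e.g. via config).
    retriever = loader_bill("data/bill.pdf")
    results = retriever.get_relevant_documents("What is the total amount due?")
    for chunk in results:
        # The docstore holds raw chunk strings, so each result is a str
        print(chunk)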