# from https://docs.streamlit.io/develop/tutorials/llms/build-conversational-apps
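#
# docv.py: a Streamlit chat demo that sends the same question to a document-vision
# chat model and to Solar Pro grounded on OCR (layout analysis) context, so the two
# answers can be compared side by side.
#
# Assuming the usual Streamlit workflow, the app is started with:
#   streamlit run docv.py
# Model names come from Streamlit secrets (DOCV_MODEL_NAME / MODEL_NAME below); the
# langchain_upstage clients also expect Upstage credentials, typically via the
# UPSTAGE_API_KEY environment variable.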
import streamlit as st
from langchain_upstage import ChatUpstage as Chat
from langchain_upstage import UpstageLayoutAnalysisLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import AIMessage, HumanMessage
from streamlit_paste_button import paste_image_button as pbutton
import base64
import io
import tempfile

# Two chat clients: a document-vision model (docv) and Solar Pro (solar_pro).
DOCV_MODEL_NAME = st.secrets["DOCV_MODEL_NAME"]
docv = Chat(model=DOCV_MODEL_NAME)

MODEL_NAME = st.secrets["MODEL_NAME"]
solar_pro = Chat(model=MODEL_NAME)
chat_with_history_prompt = ChatPromptTemplate.from_template(
    """
You are a helpful assistant.
Answer the following questions considering the history of the conversation.
Think step by step and explain your thought process as you answer.
----
Chat history: {chat_history}
----
Image context in HTML from OCR: {image_context}
----
User question: {user_query}
"""
)

def get_img_context(img_bytes):
    """Run Upstage layout analysis (OCR) on the image bytes and return page contents."""
    image_context = ""
    if img_bytes:
        with tempfile.NamedTemporaryFile(delete=True) as f:
            f.write(img_bytes)
            f.flush()  # ensure the bytes are on disk before the loader reads the path
            image_path = f.name

            layzer = UpstageLayoutAnalysisLoader(image_path, split="page")
            # For improved memory efficiency, consider using the lazy_load method
            # to load documents page by page.
            docs = layzer.load()  # or layzer.lazy_load()
            image_context = [doc.page_content for doc in docs]

    return image_context

def get_solar_pro_response(user_query, chat_history, image_context: str = None):
    chain = chat_with_history_prompt | solar_pro | StrOutputParser()
    return chain.stream(
        {
            "chat_history": chat_history,
            "image_context": image_context,
            "user_query": user_query,
        }
    )

def write_docv_response_stream(human_message):
    chain = docv | StrOutputParser()
    response = st.write_stream(
        chain.stream(st.session_state.messages + [human_message])
    )
    return response

def get_human_message(text_data, image_data=None):
    if not image_data:
        return HumanMessage(content=text_data)

    # NOTE: the data-URL MIME type is hard-coded; pasted images are PNG bytes but are
    # still sent with an image/jpeg prefix.
    return HumanMessage(
        content=[
            {"type": "text", "text": f"{text_data}"},
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
            },
        ],
    )

def get_human_message_img_url(text_data, image_url=None):
    if not image_url:
        return HumanMessage(content=text_data)

    return HumanMessage(
        content=[
            {"type": "text", "text": f"{text_data}"},
            {
                "type": "image_url",
                "image_url": {"url": f"{image_url}"},
            },
        ],
    )
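
# ---- Streamlit UI: render history, accept an image (upload or paste), then chat ----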
if "messages" not in st.session_state:
st.session_state.messages = []
for message in st.session_state.messages:
role = "AI" if isinstance(message, AIMessage) else "Human"
with st.chat_message(role):
if len(message.content) == 2:
st.markdown(message.content[0]["text"])
else:
st.markdown(message.content)
img_file_buffer = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

img_bytes = None
if img_file_buffer:
    # reset history
    st.session_state.messages = []
    st.image(img_file_buffer)
    img_bytes = img_file_buffer.read()

paste_result = pbutton("📋 Paste an image")
if paste_result.image_data is not None:
    # reset history
    st.session_state.messages = []
    st.write("Pasted image:")
    st.image(paste_result.image_data)

    img_bytes = io.BytesIO()
    paste_result.image_data.save(img_bytes, format="PNG")
    img_bytes = img_bytes.getvalue()  # Image as bytes

if prompt := st.chat_input("What is up?"):
    human_message = get_human_message(prompt)
    if img_bytes:
        # drop any previously attached image message from the history
        # so only the latest image is kept
        for message in st.session_state.messages:
            if isinstance(message, HumanMessage):
                if isinstance(message.content, list):
                    if message.content[1]["type"] == "image_url":
                        st.session_state.messages.remove(message)
                        break

        img_base64 = base64.b64encode(img_bytes).decode("utf-8")
        human_message = get_human_message(prompt, img_base64)
        img_file_buffer = None

    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        st.markdown("**Model1:**")
        response = write_docv_response_stream(human_message)

        st.markdown("**Model2:**")
        img_context = get_img_context(img_bytes)
        st.json(img_context, expanded=False)
        response2 = st.write_stream(
            get_solar_pro_response(prompt, st.session_state.messages, img_context)
        )

    st.session_state.messages.append(human_message)
    st.session_state.messages.append(AIMessage(content=response))
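
# A minimal .streamlit/secrets.toml for this app might look like the sketch below;
# the model identifiers are placeholders, not values confirmed by this code:
#
#   DOCV_MODEL_NAME = "<document-vision chat model name>"
#   MODEL_NAME = "<text chat model name, e.g. a Solar Pro model>"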