feat: working with assistant, UI, and others
iisakkirotko committed Nov 8, 2023
1 parent 1c7799f commit f228ce0
Showing 2 changed files with 146 additions and 62 deletions.
1 change: 1 addition & 0 deletions icons/explore.svg
207 changes: 145 additions & 62 deletions wanderlust.py
@@ -2,7 +2,11 @@
 import os
 
 import ipyleaflet
-import openai
+from openai import OpenAI, NotFoundError
+from openai.types.beta import Thread
+from openai.types.beta.threads import Run
+
+import time
 
 import solara
 
@@ -17,11 +21,11 @@
 markers = solara.reactive([])
 
 url = ipyleaflet.basemaps.OpenStreetMap.Mapnik.build_url()
-openai.api_key = os.getenv("OPENAI_API_KEY")
+openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
 model = "gpt-4-1106-preview"
 
 
-function_descriptions = [
+tools = [
     {
         "type": "function",
         "function": {
@@ -94,17 +98,15 @@ def add_marker(longitude, latitude, label):
 
 
 def ai_call(tool_call):
-    function = tool_call["function"]
-    name = function["name"]
-    arguments = json.loads(function["arguments"])
+    function = tool_call.function
+    name = function.name
+    arguments = json.loads(function.arguments)
     return_value = functions[name](**arguments)
-    message = {
-        "role": "tool",
-        "tool_call_id": tool_call["id"],
-        "name": tool_call["function"]["name"],
-        "content": return_value,
+    tool_outputs = {
+        "tool_call_id": tool_call.id,
+        "output": return_value,
     }
-    return message
+    return tool_outputs
 
 
 @solara.component
@@ -129,39 +131,66 @@ def Map():
 @solara.component
 def ChatInterface():
     prompt = solara.use_reactive("")
+    run_id: solara.Reactive[str] = solara.use_reactive(None)
+
+    thread: Thread = solara.use_memo(openai.beta.threads.create, dependencies=[])
+    print("thread id:", thread.id)
 
     def add_message(value: str):
         if value == "":
             return
-        messages.set(messages.value + [{"role": "user", "content": value}])
         prompt.set("")
-
-    def ask():
-        if not messages.value:
+        new_message = openai.beta.threads.messages.create(
+            thread_id=thread.id, content=value, role="user"
+        )
+        messages.set([*messages.value, new_message])
+        run_id.value = openai.beta.threads.runs.create(
+            thread_id=thread.id,
+            assistant_id="asst_RqVKAzaybZ8un7chIwPCIQdH",
+            tools=tools,
+        ).id
+        print("Run id:", run_id.value)
+
+    def poll():
+        if not run_id.value:
             return
-        last_message = messages.value[-1]
-        if last_message["role"] == "user" or last_message["role"] == "tool":
-            completion = openai.ChatCompletion.create(
-                model=model,
-                messages=messages.value,
-                # Add function calling
-                tools=function_descriptions,
-                tool_choice="auto",
-            )
-
-            output = completion.choices[0].message
-            print("received", output)
+        completed = False
+        while not completed:
             try:
-                handled_messages = handle_message(output)
-                messages.value = [*messages.value, output, *handled_messages]
-
-            except Exception as e:
-                print("errr", e)
+                run = openai.beta.threads.runs.retrieve(
+                    run_id.value, thread_id=thread.id
+                )  # When run is complete
+                print("run", run.status)
+            except NotFoundError:
+                print("run not found (Yet)")
+                continue
+            if run.status == "requires_action":
+                for tool_call in run.required_action.submit_tool_outputs.tool_calls:
+                    tool_output = ai_call(tool_call)
+                    openai.beta.threads.runs.submit_tool_outputs(
+                        thread_id=thread.id,
+                        run_id=run_id.value,
+                        tool_outputs=[tool_output],
+                    )
+            if run.status == "completed":
+                messages.set(
+                    [
+                        *messages.value,
+                        openai.beta.threads.messages.list(thread.id).data[0],
+                    ]
+                )
+                run_id.set(None)
+                completed = True
+            time.sleep(0.1)
+        retrieved_messages = openai.beta.threads.messages.list(thread_id=thread.id)
+        messages.set(retrieved_messages.data)
+
+    result = solara.use_thread(poll, dependencies=[run_id.value])
 
     def handle_message(message):
         print("handle", message)
         messages = []
-        if message["role"] == "assistant":
+        if message.role == "assistant":
             tools_calls = message.get("tool_calls", [])
             for tool_call in tools_calls:
                 messages.append(ai_call(tool_call))
@@ -173,38 +202,71 @@ def handle_initial():
             handle_message(message)
 
     solara.use_effect(handle_initial, [])
-    result = solara.use_thread(ask, dependencies=[messages.value])
+    # result = solara.use_thread(ask, dependencies=[messages.value])
     with solara.Column(
-        style={"height": "100%", "width": "38vw", "justify-content": "center"},
+        style={
+            "height": "100%",
+            "width": "38vw",
+            "justify-content": "center",
+            "background": "linear-gradient(0deg, transparent 75%, white 100%);",
+        },
         classes=["chat-interface"],
     ):
         if len(messages.value) > 0:
-            with solara.Column(style={"flex-grow": "1", "overflow-y": "auto"}):
-                for message in messages.value:
-                    if message["role"] == "user":
-                        solara.Text(
-                            message["content"], classes=["chat-message", "user-message"]
-                        )
-                    elif message["role"] == "assistant":
-                        if message["content"]:
-                            solara.Markdown(message["content"])
-                        elif message["tool_calls"]:
-                            solara.Markdown("*Calling map functions*")
+            # The height works effectively as `min-height`, since flex will grow the container to fill the available space
+            with solara.Column(
+                style={
+                    "flex-grow": "1",
+                    "overflow-y": "auto",
+                    "height": "100px",
+                    "flex-direction": "column-reverse",
+                }
+            ):
+                for message in reversed(messages.value):
+                    with solara.Row(style={"align-items": "flex-start"}):
+                        if message.role == "user":
+                            solara.Text(
+                                message.content[0].text.value,
+                                classes=["chat-message", "user-message"],
+                            )
+                            assert len(message.content) == 1
+                        elif message.role == "assistant":
+                            if message.content[0].text.value:
+                                solara.v.Icon(
+                                    children=["mdi-compass-outline"],
+                                    style_="padding-top: 10px;",
+                                )
+                                solara.Markdown(message.content[0].text.value)
+                            elif message.content.tool_calls:
+                                solara.v.Icon(
+                                    children=["mdi-map"],
+                                    style_="padding-top: 10px;",
+                                )
+                                solara.Markdown("*Calling map functions*")
+                            else:
+                                solara.v.Icon(
+                                    children=["mdi-compass-outline"],
+                                    style_="padding-top: 10px;",
+                                )
+                                solara.Preformatted(
+                                    repr(message),
+                                    classes=["chat-message", "assistant-message"],
+                                )
+                        elif message["role"] == "tool":
+                            pass  # no need to display
                         else:
+                            solara.v.Icon(
+                                children=["mdi-compass-outline"],
+                                style_="padding-top: 10px;",
+                            )
                             solara.Preformatted(
                                 repr(message),
                                 classes=["chat-message", "assistant-message"],
                             )
-                    elif message["role"] == "tool":
-                        pass  # no need to display
-                    else:
-                        solara.Preformatted(
-                            repr(message), classes=["chat-message", "assistant-message"]
-                        )
-                        # solara.Text(message, classes=["chat-message"])
+                            # solara.Text(message, classes=["chat-message"])
         with solara.Column():
             solara.InputText(
-                label="Ask your ",
+                label="Ask your question here",
                 value=prompt,
                 style={"flex-grow": "1"},
                 on_value=add_message,
Expand Down Expand Up @@ -234,26 +296,47 @@ def load():
messages.set(json.load(f))
reset_ui()

with solara.Column(style={"flex-grow": "1"}, gap=0):
with solara.AppBar():
solara.Button("Save", on_click=save)
solara.Button("Load", on_click=load)
solara.Button("Soft reset", on_click=reset_ui)
with solara.Row(style={"height": "100%"}, justify="space-between"):
with solara.Column(
style={
"height": "95vh",
"justify-content": "center",
"padding": "45px 50px 75px 50px",
},
gap="5vh",
):
with solara.Row(justify="space-between"):
with solara.Row(gap="10px", style={"align-items": "center"}):
solara.v.Icon(children=["mdi-compass-rose"], size="36px")
solara.HTML(
tag="h2",
unsafe_innerHTML="Wanderlust",
style={"display": "inline-block"},
)
# with solara.Row(gap="10px"):
# solara.Button("Save", on_click=save)
# solara.Button("Load", on_click=load)
# solara.Button("Soft reset", on_click=reset_ui)
with solara.Row(justify="space-between", style={"flex-grow": "1"}):
ChatInterface().key(f"chat-{reset_counter}")
with solara.Column(style={"width": "58vw", "justify-content": "center"}):
with solara.Column(style={"width": "50vw", "justify-content": "center"}):
Map() # .key(f"map-{reset_counter}")

solara.Style(
"""
.jupyter-widgets.leaflet-widgets{
height: 100%;
border-radius: 20px;
}
.solara-autorouter-content{
display: flex;
flex-direction: column;
justify-content: stretch;
}
.v-toolbar__title{
display: flex;
align-items: center;
column-gap: 0.5rem;
}
"""
)
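
For reference, the run/poll flow that the new add_message and poll code implements can be reduced to the standalone sketch below. It only uses the beta Assistants endpoints that already appear in the diff (beta.threads, beta.threads.messages, beta.threads.runs, submit_tool_outputs); ASSISTANT_ID and dispatch_tool are placeholders for the app's own assistant id and its ai_call helper, so this is an illustration of the pattern rather than part of the commit.

# Standalone sketch (not part of the commit): the Assistants run/poll loop
# used above, reduced to one function. ASSISTANT_ID and dispatch_tool are
# placeholders; the endpoints match those used in the diff (openai>=1.x beta).
import json
import time

from openai import NotFoundError, OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
ASSISTANT_ID = "asst_..."  # placeholder for a real assistant id


def dispatch_tool(tool_call):
    # Placeholder for the app's ai_call: run the requested function and
    # return the shape expected by submit_tool_outputs.
    arguments = json.loads(tool_call.function.arguments)
    return {"tool_call_id": tool_call.id, "output": f"called with {arguments}"}


def ask(question: str) -> str:
    thread = client.beta.threads.create()
    client.beta.threads.messages.create(thread_id=thread.id, role="user", content=question)
    run = client.beta.threads.runs.create(thread_id=thread.id, assistant_id=ASSISTANT_ID)

    while True:
        try:
            run = client.beta.threads.runs.retrieve(run.id, thread_id=thread.id)
        except NotFoundError:
            time.sleep(0.1)  # the run may not be visible yet, as in poll()
            continue
        if run.status == "requires_action":
            outputs = [
                dispatch_tool(tool_call)
                for tool_call in run.required_action.submit_tool_outputs.tool_calls
            ]
            client.beta.threads.runs.submit_tool_outputs(
                thread_id=thread.id, run_id=run.id, tool_outputs=outputs
            )
        elif run.status in ("completed", "failed", "cancelled", "expired"):
            break
        time.sleep(0.1)

    # messages.list returns newest first, hence the .data[0] access in the diff.
    newest = client.beta.threads.messages.list(thread_id=thread.id).data[0]
    return newest.content[0].text.value

Compared with the removed openai.ChatCompletion version, tool results are returned to the model through runs.submit_tool_outputs instead of being appended to a locally maintained messages list.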
