Skip to content

Commit

Permalink
Update browser logic to match updated chat format (pytorch#478)
Browse files Browse the repository at this point in the history
  • Loading branch information
GregoryComer authored and malfet committed Jul 17, 2024
1 parent c5eb5b0 commit 6c2456b
Showing 1 changed file with 48 additions and 9 deletions.
57 changes: 48 additions & 9 deletions chat_in_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,37 @@ def create_app(*args):
["python3", "generate.py", *args], stdin=subprocess.PIPE, stdout=subprocess.PIPE
)


@app.route("/")
def main():
print("Starting chat session.")
line = b""
output = ""
global disable_input

while True:
line = proc.stdout.readline()
if line.decode("utf-8").startswith("What is your prompt?"):
buffer = proc.stdout.read(1)
line += buffer
try:
decoded = line.decode("utf-8")
except:
continue

if decoded.startswith("System Prompt") and decoded.endswith(": "):
print(f"| {decoded}")
proc.stdin.write("\n".encode("utf-8"))
proc.stdin.flush()
line = b""
elif line.decode("utf-8").startswith("User: "):
print(f"| {decoded}")
break
output += line.decode("utf-8").strip() + "\n"

if decoded.endswith("\r") or decoded.endswith("\n"):
decoded = decoded.strip()
print(f"| {decoded}")
output += decoded + "\n"
line = b""

return render_template(
"chat.html",
convo="Hello! What is your prompt?",
Expand All @@ -44,23 +65,41 @@ def chat():
proc.stdin.write((_prompt + "\n").encode("utf-8"))
proc.stdin.flush()

print(f"User: {_prompt}")

line = b""
output = ""
global disable_input

while True:
line = proc.stdout.readline()
if line.decode("utf-8").startswith("What is your prompt?"):
buffer = proc.stdout.read(1)
line += buffer
try:
decoded = line.decode("utf-8")
except:
continue

if decoded.startswith("User: "):
break
if line.decode("utf-8").startswith("=========="):
if decoded.startswith("=========="):
disable_input = True
break
output += line.decode("utf-8").strip() + "\n"
if decoded.endswith("\r") or decoded.endswith("\n"):
decoded = decoded.strip()
print(f"| {decoded}")
output += decoded + "\n"
line = b""

# Strip "Model: " from output
model_prefix = "Model: "
if output.startswith(model_prefix):
output = output[len(model_prefix):]

global convo

if _prompt:
convo += "<H1>Your prompt</H1>\n<p> " + _prompt + " </p>\n\n"
convo += "<H1>My response</H1>\n<p> " + output + " </p>\n\n"
convo += "<H1>User</H1>\n<p> " + _prompt + " </p>\n\n"
convo += "<H1>Model</H1>\n<p> " + output + " </p>\n\n"

return render_template("chat.html", convo=convo, disable_input=disable_input)

Expand Down

0 comments on commit 6c2456b

Please sign in to comment.