Skip to content

Commit a956dc6

Browse files
authored
Merge pull request #16 from jepler/cancel
chap tui: add ability to cancel generation with escape key
2 parents f3bf17c + 9d03cd2 commit a956dc6

File tree

2 files changed

+62
-14
lines changed

2 files changed

+62
-14
lines changed

src/chap/commands/tui.css

+5-1
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,14 @@ Markdown {
3939
Footer {
4040
dock: top;
4141
}
42-
Input {
42+
43+
#inputbox {
4344
dock: bottom;
45+
height: 3;
4446
}
4547

48+
#inputbox Button { dock: right; display: none; }
49+
4650
Markdown {
4751
margin: 0 1 0 0;
4852
}

src/chap/commands/tui.py

+57-13
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
import sys
88

99
from markdown_it import MarkdownIt
10+
from textual import work
1011
from textual.app import App
1112
from textual.binding import Binding
12-
from textual.containers import Container, VerticalScroll
13-
from textual.widgets import Footer, Input, Markdown
13+
from textual.containers import Container, Horizontal, VerticalScroll
14+
from textual.widgets import Button, Footer, Input, Markdown
1415

1516
from ..core import command_uses_new_session, get_api, new_session_path
1617
from ..session import Assistant, Session, User
@@ -41,10 +42,16 @@ def markdown_for_step(step):
4142
)
4243

4344

45+
class CancelButton(Button):
46+
BINDINGS = [
47+
Binding("escape", "stop_generating", "Stop Generating", show=True),
48+
]
49+
50+
4451
class Tui(App):
4552
CSS_PATH = "tui.css"
4653
BINDINGS = [
47-
Binding("ctrl+c", "app.quit", "Quit", show=True, priority=True),
54+
Binding("ctrl+q", "quit", "Quit", show=True, priority=True),
4855
]
4956

5057
def __init__(self, api=None, session=None):
@@ -56,29 +63,44 @@ def __init__(self, api=None, session=None):
5663
def input(self):
5764
return self.query_one(Input)
5865

66+
@property
67+
def cancel_button(self):
68+
return self.query_one(CancelButton)
69+
5970
@property
6071
def container(self):
6172
return self.query_one("#content")
6273

6374
def compose(self):
6475
yield Footer()
65-
yield Input(placeholder="Prompt")
66-
yield VerticalScroll(Container(id="pad"), id="content")
76+
yield VerticalScroll(
77+
*[markdown_for_step(step) for step in self.session.session],
78+
Container(id="pad"),
79+
id="content",
80+
)
81+
with Horizontal(id="inputbox"):
82+
yield CancelButton(label="❌", id="cancel")
83+
yield Input(placeholder="Prompt")
6784

6885
async def on_mount(self) -> None:
69-
await self.container.mount_all(
70-
[markdown_for_step(step) for step in self.session.session], before="#pad"
71-
)
72-
# self.scrollview.scroll_y = self.scrollview.get_content_height()
73-
self.scroll_end()
86+
self.container.scroll_end(animate=False)
7487
self.input.focus()
88+
self.cancel_button.disabled = True
89+
self.cancel_button.styles.display = "none"
7590

7691
async def on_input_submitted(self, event) -> None:
92+
self.get_completion(event.value)
93+
94+
@work(exclusive=True)
95+
async def get_completion(self, query):
7796
self.scroll_end()
7897
self.input.disabled = True
98+
self.cancel_button.disabled = False
99+
self.cancel_button.styles.display = "block"
100+
self.cancel_button.focus()
79101
output = markdown_for_step(Assistant("*query sent*"))
80102
await self.container.mount_all(
81-
[markdown_for_step(User(event.value)), output], before="#pad"
103+
[markdown_for_step(User(query)), output], before="#pad"
82104
)
83105
tokens = []
84106
update = asyncio.Queue(1)
@@ -89,6 +111,14 @@ async def on_input_submitted(self, event) -> None:
89111
if not wi.has_class("history_exclude"):
90112
session.session.append(si)
91113

114+
message = Assistant("")
115+
self.session.session.extend(
116+
[
117+
User(query),
118+
message,
119+
]
120+
)
121+
92122
async def render_fun():
93123
while await update.get():
94124
if tokens:
@@ -97,25 +127,29 @@ async def render_fun():
97127
await asyncio.sleep(0.1)
98128

99129
async def get_token_fun():
100-
async for token in self.api.aask(session, event.value):
130+
async for token in self.api.aask(session, query):
101131
tokens.append(token)
132+
message.content += token
102133
try:
103134
update.put_nowait(True)
104135
except asyncio.QueueFull:
136+
# QueueFull exception is expected. If something's in the
137+
# queue then render_fun will run soon.
105138
pass
106139
await update.put(False)
107140

108141
try:
109142
await asyncio.gather(render_fun(), get_token_fun())
110143
self.input.value = ""
111144
finally:
112-
self.session.session.extend(session.session[-2:])
113145
all_output = self.session.session[-1].content
114146
output.update(all_output)
115147
output._markdown = all_output # pylint: disable=protected-access
116148
self.container.scroll_end()
117149
self.input.disabled = False
118150
self.input.focus()
151+
self.cancel_button.disabled = True
152+
self.cancel_button.styles.display = "none"
119153

120154
def scroll_end(self):
121155
self.call_after_refresh(self.container.scroll_end)
@@ -139,6 +173,16 @@ def action_toggle_history(self):
139173
children[idx].toggle_class("history_exclude")
140174
children[idx + 1].toggle_class("history_exclude")
141175

176+
async def action_stop_generating(self):
177+
self.workers.cancel_all()
178+
179+
async def on_button_pressed(self, event): # pylint: disable=unused-argument
180+
self.workers.cancel_all()
181+
182+
async def action_quit(self):
183+
self.workers.cancel_all()
184+
self.exit()
185+
142186
async def action_resubmit(self):
143187
await self.delete_or_resubmit(True)
144188

0 commit comments

Comments
 (0)