7
7
import sys
8
8
9
9
from markdown_it import MarkdownIt
10
+ from textual import work
10
11
from textual .app import App
11
12
from textual .binding import Binding
12
- from textual .containers import Container , VerticalScroll
13
- from textual .widgets import Footer , Input , Markdown
13
+ from textual .containers import Container , Horizontal , VerticalScroll
14
+ from textual .widgets import Button , Footer , Input , Markdown
14
15
15
16
from ..core import command_uses_new_session , get_api , new_session_path
16
17
from ..session import Assistant , Session , User
@@ -41,10 +42,16 @@ def markdown_for_step(step):
41
42
)
42
43
43
44
45
+ class CancelButton (Button ):
46
+ BINDINGS = [
47
+ Binding ("escape" , "stop_generating" , "Stop Generating" , show = True ),
48
+ ]
49
+
50
+
44
51
class Tui (App ):
45
52
CSS_PATH = "tui.css"
46
53
BINDINGS = [
47
- Binding ("ctrl+c " , "app. quit" , "Quit" , show = True , priority = True ),
54
+ Binding ("ctrl+q " , "quit" , "Quit" , show = True , priority = True ),
48
55
]
49
56
50
57
def __init__ (self , api = None , session = None ):
@@ -56,29 +63,44 @@ def __init__(self, api=None, session=None):
56
63
def input (self ):
57
64
return self .query_one (Input )
58
65
66
+ @property
67
+ def cancel_button (self ):
68
+ return self .query_one (CancelButton )
69
+
59
70
@property
60
71
def container (self ):
61
72
return self .query_one ("#content" )
62
73
63
74
def compose (self ):
64
75
yield Footer ()
65
- yield Input (placeholder = "Prompt" )
66
- yield VerticalScroll (Container (id = "pad" ), id = "content" )
76
+ yield VerticalScroll (
77
+ * [markdown_for_step (step ) for step in self .session .session ],
78
+ Container (id = "pad" ),
79
+ id = "content" ,
80
+ )
81
+ with Horizontal (id = "inputbox" ):
82
+ yield CancelButton (label = "❌" , id = "cancel" )
83
+ yield Input (placeholder = "Prompt" )
67
84
68
85
async def on_mount (self ) -> None :
69
- await self .container .mount_all (
70
- [markdown_for_step (step ) for step in self .session .session ], before = "#pad"
71
- )
72
- # self.scrollview.scroll_y = self.scrollview.get_content_height()
73
- self .scroll_end ()
86
+ self .container .scroll_end (animate = False )
74
87
self .input .focus ()
88
+ self .cancel_button .disabled = True
89
+ self .cancel_button .styles .display = "none"
75
90
76
91
async def on_input_submitted (self , event ) -> None :
92
+ self .get_completion (event .value )
93
+
94
+ @work (exclusive = True )
95
+ async def get_completion (self , query ):
77
96
self .scroll_end ()
78
97
self .input .disabled = True
98
+ self .cancel_button .disabled = False
99
+ self .cancel_button .styles .display = "block"
100
+ self .cancel_button .focus ()
79
101
output = markdown_for_step (Assistant ("*query sent*" ))
80
102
await self .container .mount_all (
81
- [markdown_for_step (User (event . value )), output ], before = "#pad"
103
+ [markdown_for_step (User (query )), output ], before = "#pad"
82
104
)
83
105
tokens = []
84
106
update = asyncio .Queue (1 )
@@ -89,6 +111,14 @@ async def on_input_submitted(self, event) -> None:
89
111
if not wi .has_class ("history_exclude" ):
90
112
session .session .append (si )
91
113
114
+ message = Assistant ("" )
115
+ self .session .session .extend (
116
+ [
117
+ User (query ),
118
+ message ,
119
+ ]
120
+ )
121
+
92
122
async def render_fun ():
93
123
while await update .get ():
94
124
if tokens :
@@ -97,25 +127,29 @@ async def render_fun():
97
127
await asyncio .sleep (0.1 )
98
128
99
129
async def get_token_fun ():
100
- async for token in self .api .aask (session , event . value ):
130
+ async for token in self .api .aask (session , query ):
101
131
tokens .append (token )
132
+ message .content += token
102
133
try :
103
134
update .put_nowait (True )
104
135
except asyncio .QueueFull :
136
+ # QueueFull exception is expected. If something's in the
137
+ # queue then render_fun will run soon.
105
138
pass
106
139
await update .put (False )
107
140
108
141
try :
109
142
await asyncio .gather (render_fun (), get_token_fun ())
110
143
self .input .value = ""
111
144
finally :
112
- self .session .session .extend (session .session [- 2 :])
113
145
all_output = self .session .session [- 1 ].content
114
146
output .update (all_output )
115
147
output ._markdown = all_output # pylint: disable=protected-access
116
148
self .container .scroll_end ()
117
149
self .input .disabled = False
118
150
self .input .focus ()
151
+ self .cancel_button .disabled = True
152
+ self .cancel_button .styles .display = "none"
119
153
120
154
def scroll_end (self ):
121
155
self .call_after_refresh (self .container .scroll_end )
@@ -139,6 +173,16 @@ def action_toggle_history(self):
139
173
children [idx ].toggle_class ("history_exclude" )
140
174
children [idx + 1 ].toggle_class ("history_exclude" )
141
175
176
+ async def action_stop_generating (self ):
177
+ self .workers .cancel_all ()
178
+
179
+ async def on_button_pressed (self , event ): # pylint: disable=unused-argument
180
+ self .workers .cancel_all ()
181
+
182
+ async def action_quit (self ):
183
+ self .workers .cancel_all ()
184
+ self .exit ()
185
+
142
186
async def action_resubmit (self ):
143
187
await self .delete_or_resubmit (True )
144
188
0 commit comments