Skip to content

Commit

Permalink
feat: real time progress feedback
Browse files Browse the repository at this point in the history
When stdout.flush() or stderr.flush() is called, the current output will be displayed in the message that is being evaluated. The message with be updated at most once per 3 seconds.
  • Loading branch information
vanutp committed Apr 20, 2024
1 parent 547c1c6 commit 8e85d7a
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 20 deletions.
16 changes: 11 additions & 5 deletions tgpy/_core/message_design.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tgpy import app

TITLE = 'TGPy>'
RUNNING_TITLE = 'TGPy running>'
OLD_TITLE_URL = 'https://github.com/tm-a-t/TGPy'
TITLE_URL = 'https://tgpy.tmat.me/'
FORMATTED_ERROR_HEADER = f'<b><a href="{TITLE_URL}">TGPy error&gt;</a></b>'
Expand Down Expand Up @@ -40,9 +41,10 @@ def __getitem__(self, item):
async def edit_message(
message: Message,
code: str,
result: str,
result: str = '',
traceback: str = '',
output: str = '',
is_running: bool = False,
) -> Message:
if not result and output:
result = output
Expand All @@ -51,7 +53,10 @@ async def edit_message(
result = traceback
traceback = ''

title = Utf16CodepointsWrapper(TITLE)
if is_running:
title = Utf16CodepointsWrapper(RUNNING_TITLE)
else:
title = Utf16CodepointsWrapper(TITLE)
parts = [
Utf16CodepointsWrapper(code.strip()),
Utf16CodepointsWrapper(f'{title} {str(result).strip()}'),
Expand All @@ -64,9 +69,10 @@ async def edit_message(

entities = []
offset = 0
for p in parts:
for i, p in enumerate(parts):
entities.append(MessageEntityCode(offset, len(p)))
offset += len(p) + 2
newline_cnt = 1 if i == 1 else 2
offset += len(p) + newline_cnt

entities[0] = MessageEntityPre(entities[0].offset, entities[0].length, 'python')
entities[1].offset += len(title) + 1
Expand All @@ -83,7 +89,7 @@ async def edit_message(
),
]

text = str('\n\n'.join(parts))
text = str(''.join(x + ('\n' if i == 1 else '\n\n') for i, x in enumerate(parts)))
if len(text) > 4096:
text = text[:4095] + '…'
return await message.edit(text, formatting_entities=entities, link_preview=False)
Expand Down
69 changes: 60 additions & 9 deletions tgpy/api/tgpy_eval.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import asyncio
import time
from dataclasses import dataclass
from typing import Any

from telethon.tl.custom import Message

import tgpy.api
from tgpy import app
from tgpy._core import message_design
from tgpy._core.meval import _meval
from tgpy.api.parse_code import parse_code
from tgpy.utils import FILENAME_PREFIX, numid
Expand All @@ -19,6 +22,46 @@ class EvalResult:
output: str


class Flusher:
_code: str
_message: Message | None
_flushed_output: str
_flush_timer: asyncio.Task | None
_finished: bool

def __init__(self, code: str, message: Message | None):
self._code = code
self._message = message
self._flushed_output = ''
self._flush_timer = None
self._finished = False

async def _wait_and_flush(self):
await asyncio.sleep(3)
await message_design.edit_message(
self._message,
self._code,
output=self._flushed_output,
is_running=True,
)
self._flush_timer = None

def flush_handler(self):
if not self._message or self._finished:
return
# noinspection PyProtectedMember
self._flushed_output = app.ctx._output
if self._flush_timer:
# flush already scheduled, will print the latest output
return
self._flush_timer = asyncio.create_task(self._wait_and_flush())

def set_finished(self):
if self._flush_timer:
self._flush_timer.cancel()
self._finished = True


async def tgpy_eval(
code: str,
message: Message | None = None,
Expand All @@ -32,8 +75,13 @@ async def tgpy_eval(
else:
raise ValueError('Invalid code provided')

if message:
await message_design.edit_message(message, code, is_running=True)

flusher = Flusher(code, message)

# noinspection PyProtectedMember
app.ctx._init_stdout()
app.ctx._init_stdio(flusher.flush_handler)
kwargs = {'msg': message}
if message:
# noinspection PyProtectedMember
Expand All @@ -50,13 +98,16 @@ async def tgpy_eval(
else:
kwargs['orig'] = None

new_variables, result = await _meval(
parsed,
filename,
tgpy.api.variables,
**tgpy.api.constants,
**kwargs,
)
try:
new_variables, result = await _meval(
parsed,
filename,
tgpy.api.variables,
**tgpy.api.constants,
**kwargs,
)
finally:
flusher.set_finished()
if '__all__' in new_variables:
new_variables = {
k: v for k, v in new_variables.items() if k in new_variables['__all__']
Expand All @@ -66,7 +117,7 @@ async def tgpy_eval(
# noinspection PyProtectedMember
return EvalResult(
result=result,
output=app.ctx._stdout,
output=app.ctx._output,
)


Expand Down
33 changes: 27 additions & 6 deletions tgpy/context.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,46 @@
import sys
from contextvars import ContextVar
from io import StringIO, TextIOBase
from typing import Callable

from telethon.tl.custom import Message

_is_module: ContextVar[bool] = ContextVar('_is_module')
_message: ContextVar[Message] = ContextVar('_message')
_stdout: ContextVar[StringIO] = ContextVar('_stdout')
_stderr: ContextVar[StringIO] = ContextVar('_stderr')
_flush_handler: ContextVar[Callable[[], None]] = ContextVar('_flush_handler')
_is_manual_output: ContextVar[bool] = ContextVar('_is_manual_output', default=False)


class _StdoutWrapper(TextIOBase):
def __init__(self, contextvar, fallback):
self.__contextvar = contextvar
self.__fallback = fallback

def __getobj(self):
return _stdout.get(sys.__stdout__)
return self.__contextvar.get(self.__fallback)

def write(self, s: str) -> int:
return self.__getobj().write(s)

def flush(self) -> None:
return self.__getobj().flush()
self.__getobj().flush()
if flush_handler := _flush_handler.get(None):
flush_handler()

@property
def isatty(self):
return getattr(self.__getobj(), 'isatty', None)


sys.stdout = _StdoutWrapper()
sys.stdout = _StdoutWrapper(_stdout, sys.__stdout__)
sys.stderr = _StdoutWrapper(_stderr, sys.__stderr__)


def cleanup_erases(data: str):
lines = data.replace('\r\n', '\n').split('\n')
return '\n'.join(x.rsplit('\r', 1)[-1] for x in lines)


class Context:
Expand All @@ -46,12 +61,18 @@ def _set_msg(msg: Message):
_message.set(msg)

@staticmethod
def _init_stdout():
def _init_stdio(flush_handler: Callable[[], None]):
_stdout.set(StringIO())
_stderr.set(StringIO())
_flush_handler.set(flush_handler)

@property
def _stdout(self) -> str:
return _stdout.get().getvalue()
def _output(self) -> str:
stderr = cleanup_erases(_stderr.get().getvalue())
stdout = cleanup_erases(_stdout.get().getvalue())
if stderr and stderr[-1] != '\n':
stderr += '\n'
return stderr + stdout

@property
def is_manual_output(self):
Expand Down

0 comments on commit 8e85d7a

Please sign in to comment.