Skip to content

Commit

Permalink
Fix browser init (pytorch#797)
Browse files Browse the repository at this point in the history
Update prompt chat is waiting for, which was modified by pytorch/torchchat#476

Modify logging defaults to not create a file in a temp folder without prompting the user, but rather just print an info message. Replace a few `print`s with `logging.info`. This way, information about the bandwidth achieved will be printed to the console, but not to the web-browser chat window.

Test plan:
```
% python3 torchchat.py browser stories110M &
% curl -L http://127.0.0.1:5000
% curl -d "prompt=Once upon a time" -X POST http://127.0.0.1:5000/chat
```

TODOs: 
- Add CI that repeats above steps
- Figure out if spawning a generator from the browser can be avoided

Fixes pytorch/torchchat#785
  • Loading branch information
malfet committed Jul 17, 2024
1 parent 594414f commit bc8126e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
4 changes: 3 additions & 1 deletion chat_in_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def main():
except:
continue

if decoded.startswith("System Prompt") and decoded.endswith(": "):
if decoded.endswith("Do you want to enter a system prompt? Enter y for yes and anything else for no. \n"):
print(f"| {decoded}")
proc.stdin.write("\n".encode("utf-8"))
proc.stdin.flush()
Expand Down Expand Up @@ -93,6 +93,8 @@ def chat():
model_prefix = "Model: "
if output.startswith(model_prefix):
output = output[len(model_prefix) :]
else:
print("But output is", output)

global convo

Expand Down
5 changes: 1 addition & 4 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
from build.utils import allowable_dtype_names, allowable_params_table, get_device_str
from download import download_and_convert, is_model_downloaded

FORMAT = (
"%(levelname)s: %(asctime)-15s: %(filename)s: %(funcName)s: %(module)s: %(message)s"
)
logging.basicConfig(filename="/tmp/torchchat.log", level=logging.INFO, format=FORMAT)
logging.basicConfig(level=logging.INFO,format="%(message)s")
logger = logging.getLogger(__name__)

default_device = os.getenv("TORCHCHAT_DEVICE", "fast")
Expand Down
6 changes: 3 additions & 3 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,12 +752,12 @@ def callback(x):
# Don't continue here.... because we need to report and reset
# continue

print(
logging.info(
f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
)
print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
logging.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
if i == 0:
print(
logging.info(
f"*** This first iteration will include cold start effects for dynamic import, hardware caches{', JIT compilation' if jit_compile else ''}. ***"
)
if start_pos >= max_seq_length:
Expand Down

0 comments on commit bc8126e

Please sign in to comment.