Skip to content

Commit

Permalink
fix: allow model name override via cli.
Browse files Browse the repository at this point in the history
  • Loading branch information
codito committed Oct 18, 2024
1 parent 76cd598 commit 19fcc74
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
6 changes: 5 additions & 1 deletion arey/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,15 @@ class Chat:
# return context_size - prompt_tokens_without_history - buffer


def create_chat() -> tuple[Chat, ModelMetrics]:
def create_chat(model_name: str | None) -> tuple[Chat, ModelMetrics]:
"""Create a new chat session."""
# FIXME
# system_prompt = prompt_model.get_message("system", "") if prompt_model else ""
global model
system_prompt = ""
if model_name is not None and model_name in config.models:
model = get_completion_llm(config.models[model_name])

with capture_stderr() as stderr:
model.load(system_prompt)
chat = Chat()
Expand Down
10 changes: 8 additions & 2 deletions arey/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,15 @@ def task(instruction: str, overrides_file: str, verbose: bool) -> int:


@main.command("chat")
@click.option(
"-m",
"--model",
default=None,
help="Model to use for chat, must be defined in the config.",
)
@error_handler
@common_options
def chat(verbose: bool) -> int:
def chat(verbose: bool, model: str | None) -> int:
"""Chat with an AI model."""
import readline # noqa enable GNU readline capabilities. # pyright: ignore[reportUnusedImport]
from arey.chat import create_chat, get_completion_metrics, stream_response
Expand All @@ -195,7 +201,7 @@ def chat(verbose: bool) -> int:
console.print()

with console.status("[message_footer]Loading model..."):
chat, model_metrics = create_chat()
chat, model_metrics = create_chat(model)
footer = f"✓ Model loaded. {model_metrics.init_latency_ms / 1000:.2f}s."
console.print(footer, style="message_footer")
console.print()
Expand Down

0 comments on commit 19fcc74

Please sign in to comment.