Skip to content

Commit

Permalink
better user messages, including current queue size
Browse files Browse the repository at this point in the history
  • Loading branch information
randommm committed Nov 19, 2023
1 parent 6387f32 commit fe3bcb4
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 42 deletions.
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ hf-hub = "0.3.2"
tokenizers = "0.15"
thread-priority = "0.15"
regex = "1.5"
crossbeam-channel = "0.5"

[profile.dev.package."*"]
opt-level = 3
4 changes: 2 additions & 2 deletions src/routes/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use super::{LLMReceiver, LLMSender};
use candle::quantized::gguf_file;
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use std::sync::mpsc::channel;
use crossbeam_channel::unbounded;
use std::thread;

use candle_transformers::models::quantized_llama as model;
Expand Down Expand Up @@ -99,7 +99,7 @@ impl ModelBuilder {
let (tx, rx) = if let Some(llm_receiver) = self.llm_receiver {
(None, llm_receiver)
} else {
let (tx, rx) = channel();
let (tx, rx) = unbounded();
(Some(tx), rx)
};

Expand Down
2 changes: 1 addition & 1 deletion src/routes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ use axum::{
routing::{get, post},
Router,
};
use crossbeam_channel::{Receiver, Sender};
use error_handling::AppError;
use llm::ModelBuilder;
use sqlx::SqlitePool;
use std::sync::mpsc::{Receiver, Sender};
use tokio::sync::oneshot::Sender as OneShotSender;

/// Alias for the channel sender that carries requests for the LLM.
///
/// Each message is a 3-tuple: a `String`, a `Vec<u32>`, and a one-shot sender
/// whose reply payload is itself a `(String, Vec<u32>)` pair.
///
/// NOTE(review): presumably the fields are (prompt text, pre-prompt tokens,
/// reply channel for (generated text, next pre-prompt tokens)) — confirm
/// against the call sites that build and consume these tuples.
type LLMSender = Sender<(String, Vec<u32>, OneShotSender<(String, Vec<u32>)>)>;
Expand Down
74 changes: 35 additions & 39 deletions src/routes/pages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,44 +151,35 @@ async fn process_slack_events(
let reqw_client = reqwest::Client::new();

let reply_to_user = if text == "delete" || text == "\"delete\"" {
let _ = sqlx::query(
"DELETE FROM sessions
WHERE channel = $1",
)
.bind(channel)
.execute(&db_pool)
.await;
let _ = sqlx::query("DELETE FROM sessions WHERE channel = $1")
.bind(channel)
.execute(&db_pool)
.await;

"Ok, the LLM section was deleted. A new message will start a fresh LLM section.".to_owned()
} else if text == "plot" || text == "\"plot\"" {
return plot_random_stuff(channel.to_owned(), slack_oauth_token.clone()).await;
} else {
let mut initial_message = "Running the LLM ".to_owned();
// select that checks if a state exists
let query: Result<(Vec<u8>,), _> = sqlx::query_as(
r#"SELECT model_state
FROM sessions WHERE channel = $1;"#,
)
.bind(channel)
.fetch_one(&db_pool)
.await;
let query: Result<(Vec<u8>,), _> =
sqlx::query_as(r#"SELECT model_state FROM sessions WHERE channel = $1;"#)
.bind(channel)
.fetch_one(&db_pool)
.await;

let pre_prompt_tokens = if let Ok(query) = query {
initial_message.push_str("with new section. ");
let (model_state,) = query;
let deserialized: Result<_, _> = bincode::deserialize(&model_state[..]);
deserialized.unwrap_or_default()
} else {
if PRINT_SLACK_EVENTS {
println!("Starting new section");
}
initial_message.push_str("reusing section. ");
let timestamp = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map_err(|e| format!("Error: {:?}", e))?
.as_secs() as i64;
sqlx::query(
"INSERT OR IGNORE INTO sessions
(channel, created_at, updated_at)
VALUES ($1, $2, $3);",
)
sqlx::query( "INSERT OR IGNORE INTO sessions (channel, created_at, updated_at) VALUES ($1, $2, $3);")
.bind(channel)
.bind(timestamp)
.bind(timestamp)
Expand All @@ -197,21 +188,28 @@ async fn process_slack_events(
Default::default()
};

initial_message
.push_str(format!("Current queue size: {}.", llm_model_sender.len()).as_str());
initial_message.push_str("\nNote: to delete the LLM chat section, send \"delete\".");

let form = multipart::Form::new()
.text("text", "Running the LLM...")
.text("text", initial_message)
.text("channel", channel.to_owned());
let _ = reqw_client
.post("https://slack.com/api/chat.postMessage")
.header(AUTHORIZATION, format!("Bearer {}", slack_oauth_token.0))
.multipart(form)
.send()
.await;
tokio::spawn(
reqw_client
.post("https://slack.com/api/chat.postMessage")
.header(AUTHORIZATION, format!("Bearer {}", slack_oauth_token.0))
.multipart(form)
.send(),
);

let (oneshot_tx, oneshot_rx) = oneshot::channel();
llm_model_sender
.send((text, pre_prompt_tokens, oneshot_tx))
.unwrap();
let (generated_text, next_pre_prompt_tokens) = oneshot_rx.await.unwrap();
let (generated_text, next_pre_prompt_tokens) = oneshot_rx
.await
.map_err(|e| format!("One-shot channel error: {e}"))?;

if PRINT_SLACK_EVENTS {
println!("Saving model state");
Expand All @@ -224,22 +222,20 @@ async fn process_slack_events(
.as_secs() as i64;
sqlx::query(
"INSERT INTO sessions
(channel, created_at, updated_at, model_state)
VALUES ($1, $2, $3, $4)
ON CONFLICT (channel)
DO UPDATE SET
model_state = EXCLUDED.model_state,
updated_at = EXCLUDED.updated_at;",
(channel, created_at, updated_at, model_state)
VALUES ($1, $2, $3, $4)
ON CONFLICT (channel)
DO UPDATE SET
model_state = EXCLUDED.model_state,
updated_at = EXCLUDED.updated_at;",
)
.bind(channel)
.bind(timestamp)
.bind(timestamp)
.bind(encoded)
.execute(&db_pool)
.await?;
"Reply from the LLM:\n".to_owned()
+ generated_text.as_str()
+ "\nNote: to delete the LLM chat section, send \"delete\""
"Reply from the LLM:\n".to_owned() + &generated_text[1..generated_text.len() - 4]
};

let form = multipart::Form::new()
Expand Down

0 comments on commit fe3bcb4

Please sign in to comment.