Skip to content

Commit

Permalink
each session LLM on a thread
Browse files Browse the repository at this point in the history
  • Loading branch information
randommm committed Nov 21, 2023
1 parent 327116c commit 00e1f59
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 76 deletions.
14 changes: 12 additions & 2 deletions migrations/20231021093421_initial.sql
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
CREATE TABLE "sessions" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"channel" text NOT NULL UNIQUE,
"channel" text NOT NULL,
"thread_ts" text NOT NULL,
"model_state" blob,
"created_at" integer NOT NULL,
"updated_at" integer NOT NULL
"updated_at" integer NOT NULL,
UNIQUE(channel, thread_ts)
);
CREATE TABLE "queue" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"text" text NOT NULL,
"channel" text NOT NULL,
"thread_ts" text NOT NULL,
"created_at" integer NOT NULL,
"leased_at" integer NOT NULL
);
PRAGMA journal_mode=WAL;
7 changes: 0 additions & 7 deletions migrations/20231120202003_queue.sql

This file was deleted.

148 changes: 87 additions & 61 deletions src/routes/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pub async fn start_llm_worker(db_pool: SqlitePool, slack_oauth_token: SlackOAuth
repeat_last_n
);
loop {
thread::scope(|s| {
let res = thread::scope(|s| {
s.spawn(|| {
thread_priority::set_current_thread_priority(
thread_priority::ThreadPriority::Min,
Expand Down Expand Up @@ -108,15 +108,16 @@ pub async fn start_llm_worker(db_pool: SqlitePool, slack_oauth_token: SlackOAuth

loop {
// async task to select a task from the queue
let (task_id, prompt_str, channel) = async_handle.block_on(async {
get_next_task(&db_pool)
.await
.map_err(|e| format!("Failed to get next task from queue: {e}"))
})?;
let (task_id, prompt_str, channel, thread_ts) =
async_handle.block_on(async {
get_next_task(&db_pool)
.await
.map_err(|e| format!("Failed to get next task from queue: {e}"))
})?;

// async task to get the state if it exists
let pre_prompt_tokens = async_handle.block_on(async {
get_session_state(&db_pool, &channel, &slack_oauth_token)
get_session_state(&db_pool, &channel, &thread_ts, &slack_oauth_token)
.await
.map_err(|e| format!("Failed to get session state: {e}"))
})?;
Expand Down Expand Up @@ -201,55 +202,74 @@ pub async fn start_llm_worker(db_pool: SqlitePool, slack_oauth_token: SlackOAuth
let encoded: Vec<u8> = bincode::serialize(&next_pre_prompt_tokens)
.map_err(|e| format!("Failed to encode model {e}"))?;

let _ = async_handle.block_on(async {
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map_err(|e| format!("Error: {:?}", e))?
.as_secs() as i64;
sqlx::query("DELETE FROM queue WHERE id = $1;")
.bind(task_id)
.execute(&db_pool)
.await?;
sqlx::query(
"INSERT INTO sessions
(channel, created_at, updated_at, model_state)
VALUES ($1, $2, $3, $4)
ON CONFLICT (channel)
async_handle
.block_on(async {
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map_err(|e| format!("Error: {:?}", e))?
.as_secs() as i64;
sqlx::query("DELETE FROM queue WHERE id = $1;")
.bind(task_id)
.execute(&db_pool)
.await?;
sqlx::query(
"INSERT INTO sessions
(channel, thread_ts, created_at, updated_at, model_state)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (channel, thread_ts)
DO UPDATE SET
model_state = EXCLUDED.model_state,
updated_at = EXCLUDED.updated_at;",
)
.bind(&channel)
.bind(now)
.bind(now)
.bind(encoded)
.execute(&db_pool)
.await?;

let reply_to_user = "Reply from the LLM:\n".to_owned()
+ &generated_text[1..generated_text.len() - 4];

let form = multipart::Form::new()
.text("text", reply_to_user)
.text("channel", channel.to_owned());

let reqw_response = reqwest::Client::new()
.post("https://slack.com/api/chat.postMessage")
.header(AUTHORIZATION, format!("Bearer {}", slack_oauth_token.0))
.multipart(form)
.send()
)
.bind(&channel)
.bind(&thread_ts)
.bind(now)
.bind(now)
.bind(encoded)
.execute(&db_pool)
.await?;
reqw_response.text().await.map_err(|e| {
format!("Failed to read reqwest response body: {e}")
})?;
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
});
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
})
.unwrap_or_else(|e| {
println!("Failed to save model state:\n{e}");
});

async_handle
.block_on(async {
let reply_to_user = "Reply from the LLM:\n".to_owned()
+ &generated_text[1..generated_text.len() - 4];

let form = multipart::Form::new()
.text("text", reply_to_user)
.text("channel", channel.to_owned())
.text("thread_ts", thread_ts.clone());

let reqw_response = reqwest::Client::new()
.post("https://slack.com/api/chat.postMessage")
.header(
AUTHORIZATION,
format!("Bearer {}", slack_oauth_token.0),
)
.multipart(form)
.send()
.await?;
reqw_response.text().await.map_err(|e| {
format!("Failed to read reqwest response body: {e}")
})?;
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
})
.unwrap_or_else(|e| {
println!("Failed to send user message:\n{e}");
});
}

#[allow(unreachable_code)]
Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
});
})
.join()
});
println!("LLM worker thread exited with message: {res:?}, restarting in 5 seconds");
thread::sleep(std::time::Duration::from_secs(5));
}
});
}
Expand Down Expand Up @@ -294,16 +314,16 @@ fn format_size(size_in_bytes: usize) -> String {

async fn get_next_task(
db_pool: &SqlitePool,
) -> Result<(i64, String, String), Box<dyn std::error::Error + Send + Sync>> {
let (task_id, prompt_str, channel) = loop {
) -> Result<(i64, String, String, String), Box<dyn std::error::Error + Send + Sync>> {
let (task_id, prompt_str, channel, thread_ts) = loop {
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map_err(|e| format!("Error: {:?}", e))?
.as_secs() as i64;
let mut tx = db_pool.begin().await?;
match sqlx::query_as(
"
SELECT id,text,channel FROM queue
SELECT id,text,channel,thread_ts FROM queue
WHERE leased_at <= $1
ORDER BY created_at ASC
LIMIT 0,1
Expand All @@ -314,7 +334,7 @@ async fn get_next_task(
.await
{
Ok(res) => {
let (task_id, prompt_str, channel): (i64, String, String) = res;
let (task_id, prompt_str, channel, thread_ts) = res;

if sqlx::query(
"
Expand All @@ -329,7 +349,7 @@ async fn get_next_task(
.is_ok()
&& tx.commit().await.is_ok()
{
break (task_id, prompt_str, channel);
break (task_id, prompt_str, channel, thread_ts);
}
}
Err(_) => {
Expand All @@ -338,24 +358,27 @@ async fn get_next_task(
}
}
};
Ok((task_id, prompt_str, channel))
Ok((task_id, prompt_str, channel, thread_ts))
}

async fn get_session_state(
db_pool: &SqlitePool,
channel: &str,
thread_ts: &str,
slack_oauth_token: &SlackOAuthToken,
) -> Result<Vec<u32>, Box<dyn std::error::Error + Send + Sync>> {
let mut initial_message = "Running LLM ".to_owned();
let query: Result<(Vec<u8>,), _> =
sqlx::query_as(r#"SELECT model_state FROM sessions WHERE channel = $1;"#)
.bind(channel)
.fetch_one(db_pool)
.await;
let query: Result<(Vec<u8>,), _> = sqlx::query_as(
r#"SELECT model_state FROM sessions WHERE channel = $1 AND thread_ts = $2;"#,
)
.bind(channel)
.bind(thread_ts)
.fetch_one(db_pool)
.await;
let pre_prompt_tokens = if let Ok(query) = query {
initial_message.push_str("reusing section. ");
let (model_state,) = query;
let deserialized: Result<Vec<u32>, _> = bincode::deserialize(&model_state[..]);
let deserialized = bincode::deserialize(&model_state[..]);
deserialized.unwrap_or_default()
} else {
initial_message.push_str("with new section. ");
Expand All @@ -365,9 +388,11 @@ async fn get_session_state(
.as_secs() as i64;
sqlx::query(
r#"INSERT OR IGNORE INTO
sessions (channel, created_at, updated_at) VALUES ($1, $2, $3);"#,
sessions (channel, thread_ts, created_at, updated_at)
VALUES ($1, $2, $3, $4);"#,
)
.bind(channel)
.bind(thread_ts)
.bind(timestamp)
.bind(timestamp)
.execute(db_pool)
Expand All @@ -378,7 +403,8 @@ async fn get_session_state(
let reqw_client = reqwest::Client::new();
let form = multipart::Form::new()
.text("text", initial_message)
.text("channel", channel.to_owned());
.text("channel", channel.to_owned())
.text("thread_ts", thread_ts.to_owned());
tokio::spawn(
reqw_client
.post("https://slack.com/api/chat.postMessage")
Expand Down
28 changes: 22 additions & 6 deletions src/routes/pages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ async fn process_slack_events(
};
print!("From user {user} at channel {channel} and type {type_}, received message: {text}. ");

let thread_ts = event.get("event_ts").ok_or("event_ts not found on query")?;
let thread_ts = thread_ts.as_str().ok_or("event_ts is not a string")?;

let text = match Regex::new(r" ?<@.*> ?") {
Ok(pattern) if type_ == "app_mention" => {
let text = pattern.replace_all(text, " ");
Expand All @@ -142,28 +145,37 @@ async fn process_slack_events(
let reqw_client = reqwest::Client::new();

let reply_to_user = if text == "delete" || text == "\"delete\"" {
let _ = sqlx::query("DELETE FROM sessions WHERE channel = $1")
let _ = sqlx::query("DELETE FROM sessions WHERE channel = $1 AND thread_ts = $2")
.bind(channel)
.bind(thread_ts)
.execute(&db_pool)
.await;
let _ = sqlx::query("DELETE FROM queue WHERE channel = $1")
let _ = sqlx::query("DELETE FROM queue WHERE channel = $1 AND thread_ts = $2")
.bind(channel)
.bind(thread_ts)
.execute(&db_pool)
.await;

"Ok, the LLM section was deleted. A new message will start a fresh LLM section.".to_owned()
} else if text == "plot" || text == "\"plot\"" {
return plot_random_stuff(channel.to_owned(), slack_oauth_token.clone()).await;
return plot_random_stuff(
channel.to_owned(),
thread_ts.to_owned(),
slack_oauth_token.clone(),
)
.await;
} else {
let created_at = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map_err(|e| format!("Error: {:?}", e))?
.as_secs() as i64;
sqlx::query(
"INSERT INTO queue (text, channel, created_at, leased_at) VALUES ($1, $2, $3, 0);",
"INSERT INTO queue (text, channel, thread_ts, created_at, leased_at)
VALUES ($1, $2, $3, $4, 0);",
)
.bind(text)
.bind(channel)
.bind(thread_ts)
.bind(created_at)
.execute(&db_pool)
.await?;
Expand All @@ -181,7 +193,9 @@ async fn process_slack_events(

let form = multipart::Form::new()
.text("text", reply_to_user)
.text("channel", channel.to_owned());
.text("channel", channel.to_owned())
.text("thread_ts", thread_ts.to_owned());

let reqw_response = reqw_client
.post("https://slack.com/api/chat.postMessage")
.header(AUTHORIZATION, format!("Bearer {}", slack_oauth_token.0))
Expand All @@ -198,6 +212,7 @@ async fn process_slack_events(

pub async fn plot_random_stuff(
channel: String,
thread_ts: String,
slack_oauth_token: SlackOAuthToken,
) -> Result<(), AppError> {
let mut buffer_ = vec![0; 640 * 480 * 3];
Expand Down Expand Up @@ -268,7 +283,8 @@ pub async fn plot_random_stuff(
let form = multipart::Form::new()
.text("channels", channel)
.text("title", "A plot for ya")
.part("file", part);
.part("file", part)
.text("thread_ts", thread_ts);
let reqw_response = reqw_client
.post("https://slack.com/api/files.upload")
.header(AUTHORIZATION, format!("Bearer {}", slack_oauth_token.0))
Expand Down

0 comments on commit 00e1f59

Please sign in to comment.