Skip to content

Commit fa01d91

Browse files
committed
GPT利用のcrateを変更して、最新モデルを使用するようにした。
summary機能を削除し不要なcrateを削除
1 parent 8937ca6 commit fa01d91

File tree

4 files changed

+62
-192
lines changed

4 files changed

+62
-192
lines changed

Cargo.toml

+1-4
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,9 @@ serde_yaml = "0.9.25"
1313
chrono = { version = "0.4.23", features = ["serde"] }
1414
time = "0.3.20"
1515
serde_json = "1.0.94"
16-
r2d2 = "0.8.10"
1716
rusqlite = "0.30.0"
1817
whatlang = "0.16.2"
19-
chat-gpt-rs = "1.3.0"
2018
dotenv = "0.15.0"
2119
regex = "1.7.1"
22-
mysql = "24.0.0"
23-
r2d2_mysql = "24.0.0"
2420
thiserror = "1.0.39"
21+
openai-api-rs = "4.0.5"

docker-compose.yml

-5
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,3 @@ services:
1111
working_dir: /var/bot/src
1212
command: bash -c "cargo run"
1313
restart: always
14-
15-
networks:
16-
default:
17-
name: nostify_default
18-
external: true

src/gpt.rs

+58-113
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,63 @@
11
use crate::config::AppConfig;
2-
use chat_gpt_rs::prelude::*;
32
use dotenv::dotenv;
3+
use std::error::Error;
44
use std::fs::File;
55
use std::time::Duration;
66
use std::{env, thread};
77
use tokio::time::timeout;
8+
use openai_api_rs::v1::api::Client;
9+
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
10+
use openai_api_rs::v1::common::GPT3_5_TURBO;
811

9-
pub async fn get_reply<'a>(personality: &'a str, user_text: &'a str, has_mention: bool) -> Result<String> {
12+
13+
pub async fn call_gpt(prompt: &str, user_text: &str) -> Result<String, Box<dyn Error>> {
14+
dotenv().ok();
15+
let api_key = env::var("OPEN_AI_API_KEY").expect("OPEN_AI_API_KEY is not set");
16+
let client = Client::new(api_key);
17+
let req = ChatCompletionRequest::new(
18+
GPT3_5_TURBO.to_string(),
19+
vec![
20+
chat_completion::ChatCompletionMessage {
21+
role: chat_completion::MessageRole::system,
22+
content: chat_completion::Content::Text(String::from(prompt)),
23+
name: None,
24+
},
25+
chat_completion::ChatCompletionMessage {
26+
role: chat_completion::MessageRole::user,
27+
content: chat_completion::Content::Text(String::from(user_text)),
28+
name: None,
29+
},
30+
],
31+
)
32+
.presence_penalty(-0.5)
33+
.frequency_penalty(0.0)
34+
.top_p(0.9);
35+
36+
let chat_completion_future = async {
37+
client.chat_completion(req)
38+
};
39+
40+
// タイムアウトを設定
41+
match timeout(Duration::from_secs(30), chat_completion_future).await {
42+
Ok(result) => match result {
43+
Ok(response) => {
44+
// 正常なレスポンスの処理
45+
match &response.choices[0].message.content {
46+
Some(content) => Ok(content.to_string()),
47+
None => Err("No content found in response".into()), // 適切なエラーメッセージを返す
48+
}
49+
},
50+
Err(e) => Err(e.into()), // APIErrorをBox<dyn Error>に変換
51+
},
52+
Err(_) => Err("Timeout after 30 seconds".into()),
53+
}
54+
}
55+
56+
pub async fn get_reply<'a>(personality: &'a str, user_text: &'a str, has_mention: bool) -> Result<String, Box<dyn Error>> {
1057
dotenv().ok();
1158
let file = File::open("../config.yml").unwrap();
1259
let config: AppConfig = serde_yaml::from_reader(file).unwrap();
1360
let answer_length = config.gpt.answer_length;
14-
let api_key = env::var("OPEN_AI_API_KEY").expect("OPEN_AI_API_KEY is not set");
15-
16-
let token = Token::new(&api_key);
17-
let api = Api::new(token);
1861

1962
let start_delimiter = "<<";
2063
let end_delimiter = ">>";
@@ -41,117 +84,19 @@ pub async fn get_reply<'a>(personality: &'a str, user_text: &'a str, has_mention
4184
prompt_temp = format!("これはあなたの人格です。'{personality}'\nこの人格を演じて次の行の文章に対して{answer_length}文字程度で返信してください。");
4285
}
4386
if !has_mention {
44-
prompt = format!("{prompt_temp}次の行の文章はSNSでの投稿です。あなたがたまたま見かけたものであなた宛の文章ではないのでその点に注意して解凍してください。")
87+
prompt = format!("{prompt_temp}次の行の文章はSNSでの投稿です。あなたがたまたま見かけたものであなた宛の文章ではないのでその点に注意して回答してください。")
4588
} else {
4689
prompt = prompt_temp
4790
}
4891

49-
let request = Request {
50-
model: Model::Gpt35Turbo,
51-
messages: vec![
52-
Message {
53-
role: "system".to_string(),
54-
content: prompt,
55-
},
56-
Message {
57-
role: "user".to_string(),
58-
content: user_text.to_string(),
59-
},
60-
],
61-
presence_penalty: Some(-0.5),
62-
frequency_penalty: Some(0.0),
63-
top_p: Some(0.9),
64-
65-
..Default::default()
66-
};
67-
// let response = api.chat(request).await?;
68-
let reply;
69-
let result = timeout(Duration::from_secs(30), api.chat(request)).await;
70-
match result {
71-
Ok(response) => {
72-
// 非同期処理が完了した場合の処理
73-
reply = response.unwrap().choices[0].message.content.clone();
74-
println!("{:?}", reply);
92+
match call_gpt(&prompt, &user_text.to_string()).await {
93+
Ok(reply) => {
94+
println!("Reply: {}", reply);
7595
Ok(reply)
76-
}
77-
Err(_) => {
78-
eprintln!("**********Timeout occurred while calling api.chat");
79-
reply = "".to_string();
80-
Ok(reply)
81-
}
82-
}
83-
}
84-
85-
use regex::Regex;
86-
87-
fn split_text(text: &str, max_length: usize) -> Vec<String> {
88-
let re = Regex::new(r"[^。\n]*[。\n]").unwrap();
89-
let sentences: Vec<&str> = re.find_iter(text).map(|m| m.as_str()).collect();
90-
91-
let mut result = Vec::new();
92-
let mut current = String::new();
93-
for sentence in sentences {
94-
if current.len() + sentence.len() > max_length {
95-
result.push(current.trim().to_string());
96-
current.clear();
97-
}
98-
current += sentence;
99-
}
100-
if !current.is_empty() {
101-
result.push(current.trim().to_string());
102-
}
103-
result
104-
}
105-
106-
pub async fn get_summary<'a>(text: &'a str) -> Result<String> {
107-
dotenv().ok();
108-
let api_key = env::var("OPEN_AI_API_KEY").expect("OPEN_AI_API_KEY is not set");
109-
110-
let token = Token::new(&api_key);
111-
let api = Api::new(token);
112-
113-
let prompt = format!(
114-
"あなたは優秀な新聞記者のお嬢様です。次の文章を読んで要約しお嬢様のような口調で日本語で10行にまとめてください。行頭には必ず'・'を入れて行末には必ず改行を入れてください。"
115-
);
116-
117-
let mut summary = String::from("");
118-
let split_texts = split_text(text, 2048);
119-
for _text in split_texts {
120-
loop {
121-
let request = Request {
122-
model: Model::Gpt35Turbo,
123-
messages: vec![
124-
Message {
125-
role: "system".to_string(),
126-
content: prompt.clone(),
127-
},
128-
Message {
129-
role: "user".to_string(),
130-
content: _text.to_string(),
131-
},
132-
],
133-
presence_penalty: Some(-0.5),
134-
frequency_penalty: Some(0.0),
135-
top_p: Some(0.9),
136-
137-
..Default::default()
138-
};
139-
let result = timeout(Duration::from_secs(30), api.chat(request)).await;
140-
match result {
141-
Ok(response) => {
142-
// 非同期処理が完了した場合の処理
143-
let _summary = response.unwrap().choices[0].message.content.clone();
144-
summary = format!("{}{}", summary, _summary);
145-
println!("summary:{}:{}", summary.len(), summary);
146-
break;
147-
}
148-
Err(_) => {
149-
eprintln!("**********Timeout occurred while calling api.chat");
150-
thread::sleep(Duration::from_secs(3));
151-
}
152-
}
153-
}
96+
},
97+
Err(e) => {
98+
println!("Error: {}", e);
99+
Ok("".to_string())
100+
},
154101
}
155-
summary = summary.replace("。・", "\n・");
156-
Ok(summary)
157102
}

src/main.rs

+3-70
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
mod config;
22
mod db;
3-
mod db_mysql;
43
mod gpt;
5-
use chrono::{DateTime, Utc};
4+
use chrono::Utc;
65
use config::AppConfig;
76
use dotenv::dotenv;
87
use nostr_sdk::prelude::*;
@@ -48,7 +47,7 @@ async fn is_follower(user_pubkey: &str, bot_secret_key: &str) -> Result<bool> {
4847
count += 1;
4948
println!("count:{:?}", count);
5049
if events.len() >= (config.relay_servers.read.len() / 2) ||
51-
count >= 10
50+
count >= 3
5251
{
5352
break;
5453
}
@@ -104,7 +103,7 @@ async fn get_kind0(target_pubkey: &str, bot_secret_key: &str) -> Result<Event> {
104103
count += 1;
105104
println!("count:{:?}", count);
106105
if events.len() >= (config.relay_servers.read.len() / 2) ||
107-
count >= 10
106+
count >= 3
108107
{
109108
break;
110109
}
@@ -260,72 +259,6 @@ async fn command_handler(
260259
&format!("データベースのkind 0の情報をブロードキャストしました"),
261260
)
262261
.await?;
263-
} else if lines[0].contains("summary") {
264-
let from = &lines[1];
265-
let to = &lines[2];
266-
let pool = db_mysql::connect().unwrap();
267-
let from_timestamp = db_mysql::to_unix_timestamp(&from).unwrap() - 9 * 60 * 60;
268-
let from_datetime = DateTime::<Utc>::from_utc(
269-
chrono::NaiveDateTime::from_timestamp_opt(from_timestamp, 0).unwrap(),
270-
Utc,
271-
);
272-
let from_datetime_str = from_datetime.format("%Y-%m-%d %H:%M:%S").to_string();
273-
let to_timestamp = db_mysql::to_unix_timestamp(&to).unwrap() - 9 * 60 * 60;
274-
let to_datetime = DateTime::<Utc>::from_utc(
275-
chrono::NaiveDateTime::from_timestamp_opt(to_timestamp, 0).unwrap(),
276-
Utc,
277-
);
278-
let to_datetime_str = to_datetime.format("%Y-%m-%d %H:%M:%S").to_string();
279-
let events = db_mysql::select_events(
280-
&pool,
281-
Kind::TextNote,
282-
&from_datetime_str,
283-
&to_datetime_str,
284-
);
285-
286-
if events.len() > 0 {
287-
let event_len = events.len();
288-
reply_to(
289-
&config,
290-
event.clone(),
291-
person.clone(),
292-
&format!("{from}〜{to}の{event_len}件の投稿のうち、日本語の投稿の要約を開始しますわ。しばらくお待ち遊ばせ。"),
293-
)
294-
.await?;
295-
}
296-
297-
let mut summary = String::from("");
298-
let mut event_count = 0;
299-
for event in events {
300-
let mut japanese: bool = false;
301-
if let Some(lang) = detect(&event.content) {
302-
match lang.lang() {
303-
Lang::Jpn => japanese = true,
304-
_ => (),
305-
}
306-
}
307-
if japanese
308-
&& !event.content.starts_with("lnbc")
309-
&& !event.content.contains("#まとめ除外")
310-
&& event.content.len() < 400
311-
{
312-
summary = format!("{}{}\n", summary, event.content);
313-
event_count += 1;
314-
}
315-
}
316-
while summary.len() > 1500 {
317-
summary = gpt::get_summary(&summary).await?;
318-
}
319-
print!("summary:{}", summary);
320-
reply_to(
321-
&config,
322-
event.clone(),
323-
person,
324-
&format!(
325-
"{from}〜{to}の日本語投稿{event_count}件の要約ですわ。\n{summary}\n#まとめ除外"
326-
),
327-
)
328-
.await?;
329262
}
330263
}
331264
}

0 commit comments

Comments
 (0)