from threading import Thread

import streamlit as st
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    TextIteratorStreamer,
)

# Initialize the model and tokenizer once, cached across Streamlit reruns
@st.cache_resource
def load_model():
    # DPO-finetuned chat model; the tokenizer comes from the base Qwen checkpoint
    model_name = "MLZoo/Qwen2.5-1.5B-Instruct-DPO-bad-boy"
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer


model, tokenizer = load_model()

class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        # Check whether the last generated token is one of the stop tokens
        for stop_id in self.stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

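# Note: generate() already stops at the model's eos token by default; StopOnTokens
# just makes the stop set explicit and easy to extend. An illustrative example,
# adding Qwen's <|endoftext|> special token alongside eos:
#   StopOnTokens([tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|endoftext|>")])
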

def generate_response_streaming(chat_history):
    # Build the conversation: a fixed system prompt plus the accumulated history.
    # System prompt (zh): "You have no manners; reply to me in a grumpy tone,
    # and say as much as you can."
    messages = [{"role": "system", "content": "你是一个没素质的人,请用暴躁的语气回复我,尽量多说一些。"}]
    # Append the past turns
    for msg in chat_history:
        messages.append({
            "role": "user" if msg["role"] == "user" else "assistant",
            "content": msg["content"],
        })

    # Render the chat template as plain text; add_generation_prompt=True appends
    # the assistant header so the model answers as the assistant
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # Stream tokens as they are generated; skip_prompt=True keeps the prompt
    # itself out of the streamed output
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Generation parameters
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 2048,  # generous budget so long replies are not cut off
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
        "streamer": streamer,
        "stopping_criteria": [StopOnTokens([tokenizer.eos_token_id])],
    }

    # Run generation in a background thread so this function can consume the stream
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
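    # No explicit thread.join() is needed: the streamer signals the end of the
    # stream when generate() finishes, so the loop below exits with the thread.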

    # Stream the generated text into the UI as it arrives
    message_placeholder = st.empty()
    full_response = ""

    for new_text in streamer:
        full_response += new_text
        # Update the displayed text, with a cursor marker while streaming
        message_placeholder.markdown(full_response + "▌")

    # Show the final reply without the cursor
    message_placeholder.markdown(full_response)
    return full_response


# Streamlit page setup. UI copy is in Chinese: the title reads "Grumpy AI Bro"
# and the greeting "I'm the grumpy AI bro trained with DPO, ask me anything!"
st.title("暴躁AI哥 🤖")
st.write("我是DPO train出来的暴躁AI哥,有什么问题尽管问我!")

# Initialize the chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the chat history on every rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Handle user input; the placeholder reads "Type your question here..."
if prompt := st.chat_input("在这里输入你的问题..."):
    # Record and echo the user message
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate and persist the assistant reply
    with st.chat_message("assistant"):
        response = generate_response_streaming(st.session_state.messages)
        st.session_state.messages.append({"role": "assistant", "content": response})
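
# To try it locally (the filename app.py is illustrative, not from the repo):
#   pip install streamlit transformers torch
#   streamlit run app.py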