|
2 | 2 | from datetime import datetime
|
3 | 3 |
|
4 | 4 | from lagent.actions import ActionExecutor, BingBrowser
|
5 |
| -from lagent.llms import LMDeployPipeline |
6 | 5 |
|
7 |
| -from .mindsearch_agent import MindSearchAgent, MindSearchProtocol |
8 |
| -from .mindsearch_prompt import (FINAL_RESPONSE_CN, GRAPH_PROMPT_CN, |
9 |
| - searcher_context_template_cn, |
10 |
| - searcher_input_template_cn, |
11 |
| - searcher_system_prompt_cn) |
| 6 | +import src.agent.models as llm_factory |
| 7 | +from src.agent.mindsearch_agent import MindSearchAgent, MindSearchProtocol |
| 8 | +from src.agent.mindsearch_prompt import ( |
| 9 | + FINAL_RESPONSE_CN, FINAL_RESPONSE_EN, GRAPH_PROMPT_CN, GRAPH_PROMPT_EN, |
| 10 | + searcher_context_template_cn, searcher_context_template_en, |
| 11 | + searcher_input_template_cn, searcher_input_template_en, |
| 12 | + searcher_system_prompt_cn, searcher_system_prompt_en) |
12 | 13 |
|
13 |
| -llm = LMDeployPipeline(path='internlm/internlm2_5-7b', |
14 |
| - model_name='internlm2', |
15 |
| - meta_template=[ |
16 |
| - dict(role='system', api_role='system'), |
17 |
| - dict(role='user', api_role='user'), |
18 |
| - dict(role='assistant', api_role='assistant'), |
19 |
| - dict(role='environment', api_role='environment') |
20 |
| - ], |
21 |
| - top_p=0.8, |
22 |
| - top_k=1, |
23 |
| - temperature=0, |
24 |
| - max_new_tokens=8192, |
25 |
| - repetition_penalty=1.02, |
26 |
| - stop_words=['<|im_end|>']) |
| 14 | +LLM = {} |
27 | 15 |
|
28 | 16 |
|
29 |
| -def init_agent(lang='cn', model_format='pipeline'): |
| 17 | +def init_agent(lang='cn', model_format='internlm_server'): |
| 18 | + llm = LLM.get(model_format, None) |
| 19 | + if llm is None: |
| 20 | + llm_cfg = getattr(llm_factory, model_format) |
| 21 | + if llm_cfg is None: |
| 22 | + raise NotImplementedError |
| 23 | + llm = llm_cfg.pop('type')(**llm_cfg) |
| 24 | + LLM[model_format] = llm |
30 | 25 |
|
31 | 26 | agent = MindSearchAgent(
|
32 | 27 | llm=llm,
|
33 | 28 | protocol=MindSearchProtocol(meta_prompt=datetime.now().strftime(
|
34 | 29 | 'The current date is %Y-%m-%d.'),
|
35 |
| - interpreter_prompt=GRAPH_PROMPT_CN, |
36 |
| - response_prompt=FINAL_RESPONSE_CN), |
| 30 | + interpreter_prompt=GRAPH_PROMPT_CN |
| 31 | + if lang == 'cn' else GRAPH_PROMPT_EN, |
| 32 | + response_prompt=FINAL_RESPONSE_CN |
| 33 | + if lang == 'cn' else FINAL_RESPONSE_EN), |
37 | 34 | searcher_cfg=dict(
|
38 | 35 | llm=llm,
|
39 |
| - plugin=ActionExecutor( |
| 36 | + plugin_executor=ActionExecutor( |
40 | 37 | BingBrowser(
|
41 | 38 | api_key=os.environ.get('BING_API_KEY', 'YOUR_BING_API'))),
|
42 | 39 | protocol=MindSearchProtocol(
|
43 | 40 | meta_prompt=datetime.now().strftime(
|
44 | 41 | 'The current date is %Y-%m-%d.'),
|
45 |
| - plugin_prompt=searcher_system_prompt_cn, |
| 42 | + plugin_prompt=searcher_system_prompt_cn |
| 43 | + if lang == 'cn' else searcher_system_prompt_en, |
46 | 44 | ),
|
47 |
| - template=dict(input=searcher_input_template_cn, |
48 |
| - context=searcher_context_template_cn))) |
| 45 | + template=dict(input=searcher_input_template_cn |
| 46 | + if lang == 'cn' else searcher_input_template_en, |
| 47 | + context=searcher_context_template_cn |
| 48 | + if lang == 'cn' else searcher_context_template_en))) |
49 | 49 | return agent
|
0 commit comments