Add deepseek-r1 chat template (#3072)
* Add deepseek-r1 chat template

* ut (unit tests)

* add deepseek v3
AllentDan authored Jan 27, 2025
1 parent 552bf3a commit 894af4d
Showing 2 changed files with 64 additions and 0 deletions.
28 changes: 28 additions & 0 deletions lmdeploy/model.py
@@ -234,6 +234,34 @@ def messages2prompt(self, messages, sequence_start=True, **kwargs):
        return ret


@MODELS.register_module(name=['deepseek-r1'])
class Deepseek(BaseChatTemplate):
    """Chat template of DeepSeek-R1 and DeepSeek-V3 models."""

    def __init__(self, user='<|User|>', assistant='<|Assistant|>', eoa='<|end▁of▁sentence|>', **kwargs):
        super().__init__(user=user, assistant=assistant, eoa=eoa, **kwargs)

    def get_prompt(self, prompt, sequence_start=True):
        if sequence_start:
            return '<|begin▁of▁sentence|>' + super().get_prompt(prompt, sequence_start)
        return super().get_prompt(prompt, sequence_start)

    def messages2prompt(self, messages, sequence_start=True, **kwargs):
        if sequence_start and not isinstance(messages, str):
            return '<|begin▁of▁sentence|>' + super().messages2prompt(messages, sequence_start, **kwargs)
        return super().messages2prompt(messages, sequence_start, **kwargs)

    @classmethod
    def match(cls, model_path: str) -> Optional[str]:
        """Return the model_name that was registered to MODELS.

        Args:
            model_path (str): the model path used for matching.
        """
        path = model_path.lower()
        if 'deepseek-r1' in path or 'deepseek-v3' in path:
            return 'deepseek-r1'


@MODELS.register_module(name='cogvlm')
class CogVLM(BaseChatTemplate):
"""Chat template of CogVLM model."""
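A minimal usage sketch (not part of the diff) of how the new template renders a single-turn conversation. The commented output shape is an assumption inferred from the separators configured in __init__ and BaseChatTemplate's concatenation behavior, not taken from the commit itself:

from lmdeploy.model import MODELS

# Instantiate the newly registered template by its register name.
chat_template = MODELS.get('deepseek-r1')()
messages = [{'role': 'user', 'content': 'who are you'}]
prompt = chat_template.messages2prompt(messages)
# Assumed rendered shape:
# <|begin▁of▁sentence|><|User|>who are you<|Assistant|>
print(prompt)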
36 changes: 36 additions & 0 deletions tests/test_lmdeploy/test_model.py
@@ -776,3 +776,39 @@ def test_phi3(model_path_and_name):
    ref = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    res = model.messages2prompt(messages)
    assert res.startswith(ref)


@pytest.mark.parametrize('model_path_or_name', [
    'deepseek-ai/DeepSeek-R1-Distill-Llama-8B',
    'deepseek-ai/DeepSeek-R1-Distill-Llama-70B',
    'deepseek-ai/DeepSeek-R1',
    'deepseek-ai/DeepSeek-R1-Zero',
    'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B',
    'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B',
    'deepseek-ai/DeepSeek-R1-Distill-Qwen-14B',
    'deepseek-ai/DeepSeek-R1-Distill-Qwen-32B',
    'deepseek-ai/DeepSeek-V3',
])
def test_deepseek_r1(model_path_or_name):
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path_or_name, trust_remote_code=True)
    deduced_name = best_match_model(model_path_or_name)
    chat_template = MODELS.get(deduced_name)()

    messages = [{
        'role': 'system',
        'content': 'you are a helpful assistant'
    }, {
        'role': 'user',
        'content': 'who are you'
    }, {
        'role': 'assistant',
        'content': 'I am an AI'
    }, {
        'role': 'user',
        'content': 'AGI is?'
    }]
    ref = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    lm_res = chat_template.messages2prompt(messages)
    assert ref == lm_res
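Because Deepseek.match() keys on 'deepseek-r1' or 'deepseek-v3' in the lowercased model path, the template should be selected automatically when a matching checkpoint is served. A minimal sketch using lmdeploy's pipeline API, assuming the weights are available locally or on the Hugging Face Hub:

from lmdeploy import pipeline

# The chat template is deduced from the model path via match(), so no
# explicit template name needs to be passed here.
pipe = pipeline('deepseek-ai/DeepSeek-R1-Distill-Qwen-7B')
print(pipe(['who are you']))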
