support Knowledge Distillation sampling (#3185)
* tmp commit
* support distill_sampling
* update doc
* fix lint
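In short, the commit adds an OpenAI_Engine wrapper around an OpenAI-compatible chat endpoint and a DistillSampler that collects teacher-model completions through it, plus a small client test.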
Showing 6 changed files with 231 additions and 1 deletion.
@@ -0,0 +1,147 @@
import os
from copy import deepcopy
from typing import List, Optional

from openai import OpenAI

from swift.llm.infer.protocol import InferRequest, RequestConfig
from swift.llm.sampling.vanilla_sampler import VanillaSampler
from .utils import get_messages_md5


class OpenAI_Engine:

    def __init__(
        self,
        model: str,
        stream: bool = False,
        base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1',
        api_key: str = '',
        **kwargs,
    ):
        self.model = model
        self.stream = stream
        # Fall back to the OPENAI_API_KEY environment variable when no key is given.
        self.client = OpenAI(api_key=api_key if api_key else os.getenv('OPENAI_API_KEY'), base_url=base_url, **kwargs)

    def infer(
        self,
        infer_requests: List[InferRequest],
        request_config: Optional[RequestConfig] = None,
    ):
        resp_contents = []
        for infer_request in infer_requests:
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=infer_request['messages'],
                temperature=request_config.temperature,
                top_p=request_config.top_p,
                max_tokens=request_config.max_tokens,
                stream=self.stream,
            )
            # Initialize both fields up front so the non-streaming branch cannot hit
            # an undefined `reasoning_content` when the model returns no reasoning.
            reasoning_content = ''
            content = ''
            if self.stream:
                # Accumulate the streamed reasoning and answer deltas chunk by chunk.
                for chunk in completion:
                    chunk_choices = chunk.choices
                    if len(chunk_choices) == 0:
                        continue
                    reasoning_chunk = chunk_choices[0].delta.reasoning_content if hasattr(
                        chunk_choices[0].delta, 'reasoning_content') else ''
                    answer_chunk = chunk_choices[0].delta.content
                    if reasoning_chunk:
                        reasoning_content += reasoning_chunk
                    elif answer_chunk:
                        content += answer_chunk
            else:
                # `reasoning_content` is only present in responses from reasoning models.
                if hasattr(completion.choices[0].message, 'reasoning_content'):
                    reasoning_content = completion.choices[0].message.reasoning_content
                content = completion.choices[0].message.content
            assert len(content) > 0, 'Empty completion'
            # Wrap the chain of thought in <think>/<answer> tags when it is available.
            if reasoning_content:
                resp_content = f'<think>{reasoning_content}</think>\n\n<answer>{content}</answer>'
            else:
                resp_content = content
            resp_contents.append(resp_content)

        return resp_contents
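A minimal usage sketch of the engine above (the model name, prompt, and generation settings are illustrative, not part of the commit; it assumes OPENAI_API_KEY is set or an api_key is passed, and that the endpoint is OpenAI-compatible):

    # Hypothetical example, not from the diff.
    engine = OpenAI_Engine(model='deepseek-r1', stream=False)
    config = RequestConfig(max_tokens=2048, temperature=0.6, top_p=0.95)
    resps = engine.infer(
        [{'messages': [{'role': 'user', 'content': 'What is 1 + 1?'}]}],
        request_config=config,
    )
    print(resps[0])  # '<think>...</think>\n\n<answer>...</answer>' for reasoning models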


class DistillSampler(VanillaSampler):

    def __init__(self, *args, **kwargs):
        # Deliberately skip VanillaSampler.__init__ (which would set up a local
        # inference engine) and run the base sampler initialization instead.
        super(VanillaSampler, self).__init__(*args, **kwargs)
        assert self.args.sampler_engine == 'client'
        _Engine = OpenAI_Engine
        self.infer_engine = _Engine(model=self.args.model, stream=self.args.stream, **self.args.engine_kwargs)
        self.caches = self.read_cache()

    def _prepare_model_tokenizer(self):
        # No local model is loaded; inference goes through the remote client.
        pass

    def _prepare_template(self):
        pass

    def extract_choice(self, resp):
        message = resp.choices[0].message
        if hasattr(message, 'reasoning_content'):
            resp_content = f'<think>{message.reasoning_content}</think>\n\n<answer>{message.content}</answer>'
        else:
            resp_content = message.content
        return resp_content

    def generate(self, data):
        resp_all = []
        infer_requests = []
        sent = 0
        rows = self.convert_data_to_rows(data)
        for idx, row in enumerate(rows):
            row = deepcopy(row)
            messages = row['messages']
            uuid = get_messages_md5(row)
            # Skip rows whose completions are already fully cached.
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                if len(choices) == self.args.num_return_sequences:
                    continue
            # Inject or override the system prompt when one is configured.
            if self.args.system:
                if messages[0]['role'] == 'system':
                    messages[0]['content'] = self.args.system
                else:
                    messages.insert(0, {'role': 'system', 'content': self.args.system})
            # Drop a trailing assistant turn so the teacher generates it afresh.
            if messages[-1]['role'] == 'assistant':
                messages = messages[:-1]

            row['messages'] = messages
            infer_request = row
            # Queue num_return_sequences copies of each request.
            for i in range(self.args.num_return_sequences):
                infer_requests.append(deepcopy(infer_request))
            sent += 1

        request_config = RequestConfig(
            max_tokens=self.args.max_new_tokens,
            temperature=self.args.temperature,
            top_k=self.args.top_k,
            top_p=self.args.top_p,
        )

        resp_list = []
        if len(infer_requests) > 0:
            resp_list = self.infer_engine.infer(infer_requests, request_config=request_config)

        # Stitch the flat response list back onto the rows, keeping cached rows intact.
        _cur = 0
        for idx, row in enumerate(rows):
            row = deepcopy(row)
            uuid = get_messages_md5(row)
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                if len(choices) == self.args.num_return_sequences:
                    row['choices'] = choices
                    resp_all.append(row)
                    continue

            resps = row
            resps['choices'] = []
            for j in range(self.args.num_return_sequences * _cur, self.args.num_return_sequences * (_cur + 1)):
                resps['choices'].append(resp_list[j])
            resp_all.append(resps)
            _cur += 1
        return resp_all
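Each row returned by generate() keeps its original fields and gains a 'choices' list holding num_return_sequences teacher completions. A rough sketch of consuming that output (illustrative only; assumes `sampler` is a DistillSampler built from SamplingArguments as in the test below, and `data` is a dataset slice in messages format):

    rows = sampler.generate(data)
    for row in rows:
        for choice in row['choices']:  # one completion per num_return_sequences
            print(choice)              # '<think>...</think>\n\n<answer>...</answer>' strings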
@@ -0,0 +1,35 @@
import os

def test_client():
    from swift.llm import sampling_main, SamplingArguments
    import json

    # engine_kwargs is forwarded to OpenAI_Engine as a JSON string.
    base_url = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
    api_key = os.environ.get('OPENAI_API_KEY')
    engine_kwargs = json.dumps({
        'base_url': base_url,
        'api_key': api_key,
    })
    dataset = 'tastelikefeet/competition_math#5'
    system = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user
with the answer. The reasoning process and answer are enclosed
within <think> </think> and <answer> </answer> tags, respectively,
i.e., <think> reasoning process here </think> <answer> answer here </answer>."""
    args = SamplingArguments(
        sampler_type='distill',
        sampler_engine='client',
        model='deepseek-r1',
        dataset=dataset,
        num_return_sequences=1,
        stream=True,
        system=system,
        temperature=0.6,
        top_p=0.95,
        engine_kwargs=engine_kwargs,
    )
    sampling_main(args)


if __name__ == '__main__':
    test_client()
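Running this test requires a valid OPENAI_API_KEY for the DashScope-compatible endpoint. With stream=True, the engine concatenates the streamed reasoning and answer deltas before wrapping them in <think>/<answer> tags, so each sampled choice arrives as a single tagged string.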