support Knowledge Distillation sampling (#3185)
* tmp commit

* support distill_sampling

* update doc

* fix lint
mi804 authored Feb 19, 2025
1 parent d3fab8c commit a5b33ce
Showing 6 changed files with 231 additions and 1 deletion.
20 changes: 20 additions & 0 deletions docs/source/Instruction/采样.md
@@ -72,3 +72,23 @@ prms = {'custom': CustomPRM}
Please refer to the [reinforcement fine-tuning script](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py). The script gives a practical example of using sampling for reinforcement fine-tuning.

> Note: the actual effect of this script is strongly tied to the quality of the model, the data, and the RM, so it is provided only as an example; users should adapt the script themselves and train their own RM and generator models.

## Large Model Distillation Sampling

SWIFT's `sample` command supports distilling data from a large model via the OpenAI API, as in the following example:
```shell
OPENAI_API_KEY="your_api_key" \
swift sample \
--sampler_type distill \
--sampler_engine client \
--model deepseek-r1 \
--stream true \
--dataset tastelikefeet/competition_math#5 \
--num_return_sequences 1 \
--temperature 0.6 \
--top_p 0.95 \
--engine_kwargs '{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}'
```
In the example above, `base_url` and `model` are the API endpoint and the model name, and `stream` is the stream parameter sent with the request.

Note: for Deepseek-R1 series models, the output is formatted as `<think>{reasoning_content}</think>\n\n<answer>{content}</answer>`.
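
Downstream training usually needs the reasoning and the answer separated again. A minimal parsing sketch, assuming exactly the format above (`split_distilled` is an illustrative helper, not part of SWIFT):

```python
import re

def split_distilled(text: str):
    """Return (reasoning, answer) from a distilled sample.

    Assumes the `<think>...</think>\n\n<answer>...</answer>` format documented
    above; falls back to (None, text) when the tags are absent.
    """
    m = re.fullmatch(r'<think>(.*?)</think>\n\n<answer>(.*)</answer>', text, re.DOTALL)
    if m is None:
        return None, text
    return m.group(1), m.group(2)
```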
23 changes: 23 additions & 0 deletions docs/source_en/Instruction/Sample.md
@@ -75,3 +75,26 @@ By dividing the process into two stages, only one model is loaded at a time, avo
Please refer to the [Reinforcement Fine-Tuning Script](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py). This script provides a practical example of using sampling for reinforcement fine-tuning.

> **Note:** The actual effectiveness of this script is strongly related to the quality of the model, data, and RM. Therefore, it is presented only as an example. Users should modify this script and train their own RM and generator models accordingly.

## Large Model Distillation Sampling

SWIFT's `sample` command supports distilling data from large models through the OpenAI API. Example:

```shell
OPENAI_API_KEY="your_api_key" \
swift sample \
--sampler_type distill \
--sampler_engine client \
--model deepseek-r1 \
--stream true \
--dataset tastelikefeet/competition_math#5 \
--num_return_sequences 1 \
--temperature 0.6 \
--top_p 0.95 \
--engine_kwargs '{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}'
```
In this example, `base_url` and `model` are the API endpoint and the model name, respectively, and `stream` sets the stream parameter for the request.

Note: For Deepseek-R1 series models, the output will be formatted as `<think>{reasoning_content}</think>\n\n<answer>{content}</answer>`.
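
For reference, each request issued by the `client` engine is roughly equivalent to the following direct OpenAI SDK call (a sketch mirroring the `OpenAI_Engine` class added in this commit; the endpoint and model follow the example above, and the prompt is illustrative):

```python
import os
from openai import OpenAI

# One chat-completions call against an OpenAI-compatible endpoint,
# as `swift sample --sampler_engine client` issues per sample.
client = OpenAI(
    api_key=os.getenv('OPENAI_API_KEY'),
    base_url='https://dashscope.aliyuncs.com/compatible-mode/v1',
)
completion = client.chat.completions.create(
    model='deepseek-r1',
    messages=[{'role': 'user', 'content': 'What is the derivative of x^2?'}],
    temperature=0.6,
    top_p=0.95,
)
message = completion.choices[0].message
# DeepSeek-R1 exposes the chain of thought as `reasoning_content`.
print(getattr(message, 'reasoning_content', None))
print(message.content)
```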
2 changes: 1 addition & 1 deletion swift/llm/argument/sampling_args.py
@@ -21,7 +21,7 @@ class SamplingArguments(BaseArguments):

     # sampler settings
     # sample/mcts/dvts/xxx
-    sampler_type: Literal['sample', 'mcts'] = 'sample'
+    sampler_type: Literal['sample', 'mcts', 'distill'] = 'sample'
     sampler_engine: Literal['pt', 'lmdeploy', 'vllm', 'no', 'client'] = 'pt'
     output_dir: str = 'sample_output'
     output_file: Optional[str] = None
147 changes: 147 additions & 0 deletions swift/llm/sampling/distill_sampler.py
@@ -0,0 +1,147 @@
import os
from copy import deepcopy
from typing import List, Optional

from openai import OpenAI

from swift.llm.infer.protocol import InferRequest, RequestConfig
from swift.llm.sampling.vanilla_sampler import VanillaSampler
from .utils import get_messages_md5


class OpenAI_Engine:
    """Thin client for any OpenAI-compatible chat-completions endpoint."""

    def __init__(
        self,
        model: str,
        stream: bool = False,
        base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1',
        api_key: str = '',
        **kwargs,
    ):
        self.model = model
        self.stream = stream
        self.client = OpenAI(api_key=api_key if api_key else os.getenv('OPENAI_API_KEY'), base_url=base_url, **kwargs)

    def infer(
        self,
        infer_requests: List[InferRequest],
        request_config: Optional[RequestConfig] = None,
    ):
        # Issue one chat-completions request per sample and collect the texts.
        resp_contents = []
        for infer_request in infer_requests:
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=infer_request['messages'],
                temperature=request_config.temperature,
                top_p=request_config.top_p,
                max_tokens=request_config.max_tokens,
                stream=self.stream,
            )
            if self.stream:
                # Accumulate streamed deltas; reasoning models send the chain of
                # thought in `delta.reasoning_content` and the answer in `delta.content`.
                reasoning_content = ''
                content = ''
                for chunk in completion:
                    chunk_choices = chunk.choices
                    if len(chunk_choices) == 0:
                        continue
                    delta = chunk_choices[0].delta
                    reasoning_chunk = getattr(delta, 'reasoning_content', '')
                    answer_chunk = delta.content
                    if reasoning_chunk:
                        reasoning_content += reasoning_chunk
                    elif answer_chunk:
                        content += answer_chunk
            else:
                message = completion.choices[0].message
                # `reasoning_content` only exists on reasoning models; default to ''
                # so the check below never reads an unbound variable.
                reasoning_content = getattr(message, 'reasoning_content', '') or ''
                content = message.content
            assert content, 'Empty completion'
            if reasoning_content:
                resp_content = f'<think>{reasoning_content}</think>\n\n<answer>{content}</answer>'
            else:
                resp_content = content
            resp_contents.append(resp_content)

        return resp_contents


class DistillSampler(VanillaSampler):

    def __init__(self, *args, **kwargs):
        # Deliberately skip VanillaSampler.__init__ (which would build a local
        # inference engine) and run the base Sampler initialization instead.
        super(VanillaSampler, self).__init__(*args, **kwargs)
        assert self.args.sampler_engine == 'client'
        self.infer_engine = OpenAI_Engine(model=self.args.model, stream=self.args.stream, **self.args.engine_kwargs)
        self.caches = self.read_cache()

    def _prepare_model_tokenizer(self):
        # No local model is loaded; all requests go through the remote API.
        pass

    def _prepare_template(self):
        pass

    def extract_choice(self, resp):
        message = resp.choices[0].message
        if hasattr(message, 'reasoning_content'):
            resp_content = f'<think>{message.reasoning_content}</think>\n\n<answer>{message.content}</answer>'
        else:
            resp_content = message.content
        return resp_content

    def generate(self, data):
        resp_all = []
        infer_requests = []
        sent = 0
        rows = self.convert_data_to_rows(data)
        # First pass: build one request per missing sequence, skipping rows
        # whose responses are already complete in the cache.
        for idx, row in enumerate(rows):
            row = deepcopy(row)
            messages = row['messages']
            uuid = get_messages_md5(row)
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                if len(choices) == self.args.num_return_sequences:
                    continue
            if self.args.system:
                # Override or prepend the system prompt from the arguments.
                if messages[0]['role'] == 'system':
                    messages[0]['content'] = self.args.system
                else:
                    messages.insert(0, {'role': 'system', 'content': self.args.system})
            if messages[-1]['role'] == 'assistant':
                # Drop a trailing reference answer so the model generates its own.
                messages = messages[:-1]

            row['messages'] = messages
            infer_request = row
            for i in range(self.args.num_return_sequences):
                infer_requests.append(deepcopy(infer_request))
            sent += 1

        request_config = RequestConfig(
            max_tokens=self.args.max_new_tokens,
            temperature=self.args.temperature,
            top_k=self.args.top_k,
            top_p=self.args.top_p,
        )

        resp_list = []
        if len(infer_requests) > 0:
            resp_list = self.infer_engine.infer(infer_requests, request_config=request_config)

        # Second pass: merge cached rows with freshly generated responses,
        # preserving the original row order.
        _cur = 0
        for idx, row in enumerate(rows):
            row = deepcopy(row)
            uuid = get_messages_md5(row)
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                if len(choices) == self.args.num_return_sequences:
                    row['choices'] = choices
                    resp_all.append(row)
                    continue

            resps = row
            resps['choices'] = []
            for j in range(self.args.num_return_sequences * _cur, self.args.num_return_sequences * (_cur + 1)):
                resps['choices'].append(resp_list[j])
            resp_all.append(resps)
            _cur += 1
        return resp_all
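
As a usage note (not part of this commit's diff), the engine above can also be driven standalone. A minimal sketch, assuming `OPENAI_API_KEY` is set and using an illustrative prompt:

```python
# Sketch: calling OpenAI_Engine directly, outside `swift sample`.
from swift.llm.infer.protocol import RequestConfig
from swift.llm.sampling.distill_sampler import OpenAI_Engine

engine = OpenAI_Engine(model='deepseek-r1', stream=True)  # api_key falls back to OPENAI_API_KEY
rows = [{'messages': [{'role': 'user', 'content': 'What is 17 * 24?'}]}]
config = RequestConfig(max_tokens=2048, temperature=0.6, top_p=0.95)
print(engine.infer(rows, request_config=config)[0])
```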
5 changes: 5 additions & 0 deletions swift/llm/sampling/sampling.py
@@ -30,6 +30,11 @@ def __init__(self, args: Union[List[str], SamplingArguments, None] = None) -> None
         elif self.args.sampler_type == 'mcts':
             from swift.llm.sampling.mcts import MctsSampler
             self.sampler = MctsSampler(self.args)
+        elif self.args.sampler_type == 'distill':
+            from swift.llm.sampling.distill_sampler import DistillSampler
+            self.sampler = DistillSampler(self.args)
+        else:
+            raise ValueError(f'Unsupported sampler type: {self.args.sampler_type}')
 
     def _get_dataset(self):
         args = self.args
35 changes: 35 additions & 0 deletions tests/sample/test_client.py
@@ -0,0 +1,35 @@
import os


def test_client():
    from swift.llm import sampling_main, SamplingArguments
    import json
    base_url = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
    api_key = os.environ.get('OPENAI_API_KEY')
    engine_kwargs = json.dumps({
        'base_url': base_url,
        'api_key': api_key,
    })
    dataset = 'tastelikefeet/competition_math#5'
    system = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user
with the answer. The reasoning process and answer are enclosed
within <think> </think> and <answer> </answer> tags, respectively,
i.e., <think> reasoning process here </think> <answer> answer here </answer>."""
    args = SamplingArguments(
        sampler_type='distill',
        sampler_engine='client',
        model='deepseek-r1',
        dataset=dataset,
        num_return_sequences=1,
        stream=True,
        system=system,
        temperature=0.6,
        top_p=0.95,
        engine_kwargs=engine_kwargs,
    )
    sampling_main(args)


if __name__ == '__main__':
    test_client()
