Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Knowledge Distillation sampling #3185

Merged
merged 4 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/source/Instruction/采样.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,23 @@ prms = {'custom': CustomPRM}
请参考[强化微调脚本](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py)。该脚本给出了使用采样进行强化微调的实际例子。

> 注意:该脚本的实际效果和模型、数据、RM的质量强相关,因此仅作为样例出现,用户请自行修改该脚本并训练自己的RM和generator模型。

## 大模型蒸馏采样

SWIFT的sample支持使用OpenAI API的方式,用大模型蒸馏数据,如下示例:
```shell
OPENAI_API_KEY="your_api_key" \
swift sample \
--sampler_type distill \
--sampler_engine client \
--model deepseek-r1 \
--stream true \
--dataset tastelikefeet/competition_math#5 \
--num_return_sequences 1 \
--temperature 0.6 \
--top_p 0.95 \
--engine_kwargs '{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}'
```
在以上示例中,base_url和model分别是api地址和模型名称,stream表示发起请求的stream参数。

注意,对于Deepseek-R1系列模型,输出会被格式化为:`<think>{reasoning_content}</think>\n\n<answer>{content}</answer>`。
23 changes: 23 additions & 0 deletions docs/source_en/Instruction/Sample.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,26 @@ By dividing the process into two stages, only one model is loaded at a time, avo
Please refer to the [Reinforcement Fine-Tuning Script](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py). This script provides a practical example of using sampling for reinforcement fine-tuning.

> **Note:** The actual effectiveness of this script is strongly related to the quality of the model, data, and RM. Therefore, it is presented only as an example. Users should modify this script and train their own RM and generator models accordingly.

## Sampling From Large Model

SWIFT's sample supports using the OpenAI API to distill data with large models. Example:

```shell
OPENAI_API_KEY="your_api_key" \
swift sample \
--sampler_type distill \
--sampler_engine client \
--model deepseek-r1 \
--stream true \
--dataset tastelikefeet/competition_math#5 \
--num_return_sequences 1 \
--temperature 0.6 \
--top_p 0.95 \
--engine_kwargs '{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}'
```
In this example:

`base_url` and `model` represent the API endpoint and model name, respectively. `stream` indicates the stream parameter for the request.

Note: For Deepseek-R1 series models, the output will be formatted as:`<thinking>{reasoning_content}</thinking>\n\n<answer>{content}</answer>`.
2 changes: 1 addition & 1 deletion swift/llm/argument/sampling_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class SamplingArguments(BaseArguments):

# sampler settings
# sample/mcts/dvts/xxx
sampler_type: Literal['sample', 'mcts'] = 'sample'
sampler_type: Literal['sample', 'mcts', 'distill'] = 'sample'
sampler_engine: Literal['pt', 'lmdeploy', 'vllm', 'no', 'client'] = 'pt'
output_dir: str = 'sample_output'
output_file: Optional[str] = None
Expand Down
147 changes: 147 additions & 0 deletions swift/llm/sampling/distill_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import os
from copy import deepcopy
from typing import List, Optional

from openai import OpenAI

from swift.llm.infer.protocol import InferRequest, RequestConfig
from swift.llm.sampling.vanilla_sampler import VanillaSampler
from .utils import get_messages_md5


class OpenAI_Engine():

def __init__(
self,
model: str,
stream: bool = False,
base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1',
api_key: str = '',
**kwargs,
):
self.model = model
self.stream = stream
self.client = OpenAI(api_key=api_key if api_key else os.getenv('OPENAI_API_KEY'), base_url=base_url, **kwargs)

def infer(
self,
infer_requests: List[InferRequest],
request_config: Optional[RequestConfig] = None,
):
resp_contents = []
for infer_request in infer_requests:
completion = self.client.chat.completions.create(
model=self.model,
messages=infer_request['messages'],
temperature=request_config.temperature,
top_p=request_config.top_p,
max_tokens=request_config.max_tokens,
stream=self.stream,
)
if self.stream:
reasoning_content = ''
content = ''
for chunk in completion:
chunk_choices = chunk.choices
if len(chunk_choices) == 0:
continue
reasoning_chunk = chunk_choices[0].delta.reasoning_content if hasattr(
chunk_choices[0].delta, 'reasoning_content') else ''
answer_chunk = chunk_choices[0].delta.content
if reasoning_chunk:
reasoning_content += reasoning_chunk
elif answer_chunk:
content += answer_chunk
else:
if hasattr(completion.choices[0].message, 'reasoning_content'):
reasoning_content = completion.choices[0].message.reasoning_content
content = completion.choices[0].message.content
assert len(content) > 0, 'Empty completion'
if reasoning_content:
resp_content = f'<think>{reasoning_content}</think>\n\n<answer>{content}</answer>'
else:
resp_content = content
resp_contents.append(resp_content)

return resp_contents


class DistillSampler(VanillaSampler):

def __init__(self, *args, **kwargs):
super(VanillaSampler, self).__init__(*args, **kwargs)
assert self.args.sampler_engine == 'client'
_Engine = OpenAI_Engine
self.infer_engine = _Engine(model=self.args.model, stream=self.args.stream, **self.args.engine_kwargs)
self.caches = self.read_cache()

def _prepare_model_tokenizer(self):
pass

def _prepare_template(self):
pass

def extract_choice(self, resp):
message = resp.choices[0].message
if hasattr(message, 'reasoning_content'):
reps_content = f'<think>{message.reasoning_content}</think>\n\n<answer>{message.content}</answer>'
else:
reps_content = message.content
return reps_content

def generate(self, data):
resp_all = []
infer_requests = []
sent = 0
rows = self.convert_data_to_rows(data)
for idx, row in enumerate(rows):
row = deepcopy(row)
messages = row['messages']
uuid = get_messages_md5(row)
if uuid in self.caches:
choices = self.caches[uuid]['choices']
if len(choices) == self.args.num_return_sequences:
continue
if self.args.system:
if messages[0]['role'] == 'system':
messages[0]['content'] = self.args.system
else:
messages.insert(0, {'role': 'system', 'content': self.args.system})
if messages[-1]['role'] == 'assistant':
messages = messages[:-1]

row['messages'] = messages
infer_request = row
for i in range(self.args.num_return_sequences):
infer_requests.append(deepcopy(infer_request))
sent += 1

request_config = RequestConfig(
max_tokens=self.args.max_new_tokens,
temperature=self.args.temperature,
top_k=self.args.top_k,
top_p=self.args.top_p,
)

resp_list = []
if len(infer_requests) > 0:
resp_list = self.infer_engine.infer(infer_requests, request_config=request_config)

_cur = 0
for idx, row in enumerate(rows):
row = deepcopy(row)
uuid = get_messages_md5(row)
if uuid in self.caches:
choices = self.caches[uuid]['choices']
if len(choices) == self.args.num_return_sequences:
row['choices'] = choices
resp_all.append(row)
continue

resps = row
resps['choices'] = []
for j in range(self.args.num_return_sequences * _cur, self.args.num_return_sequences * (_cur + 1)):
resps['choices'].append(resp_list[j])
resp_all.append(resps)
_cur += 1
return resp_all
5 changes: 5 additions & 0 deletions swift/llm/sampling/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ def __init__(self, args: Union[List[str], SamplingArguments, None] = None) -> No
elif self.args.sampler_type == 'mcts':
from swift.llm.sampling.mcts import MctsSampler
self.sampler = MctsSampler(self.args)
elif self.args.sampler_type == 'distill':
from swift.llm.sampling.distill_sampler import DistillSampler
self.sampler = DistillSampler(self.args)
else:
raise ValueError(f'Unsupported sampler type: {self.args.sampler_type}')

def _get_dataset(self):
args = self.args
Expand Down
35 changes: 35 additions & 0 deletions tests/sample/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os


def test_client():
from swift.llm import sampling_main, SamplingArguments
import json
base_url = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
api_key = os.environ.get('OPENAI_API_KEY')
engine_kwargs = json.dumps({
'base_url': base_url,
'api_key': api_key,
})
dataset = 'tastelikefeet/competition_math#5'
system = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user
with the answer. The reasoning process and answer are enclosed
within <think> </think> and <answer> </answer> tags, respectively,
i.e., <think> reasoning process here </think> <answer> answer here </answer>."""
args = SamplingArguments(
sampler_type='distill',
sampler_engine='client',
model='deepseek-r1',
dataset=dataset,
num_return_sequences=1,
stream=True,
system=system,
temperature=0.6,
top_p=0.95,
engine_kwargs=engine_kwargs,
)
sampling_main(args)


if __name__ == '__main__':
test_client()
Loading