188 changes: 188 additions & 0 deletions verl/trainer/config/gsm8k.yaml
@@ -0,0 +1,188 @@
data:
  tokenizer: null
  train_files: /root/data/gsm8k/train.parquet
  val_files: /root/data/gsm8k/test.parquet
  prompt_key: prompt
  reward_fn_key: data_source
  max_prompt_length: 512
  max_response_length: 512
  train_batch_size: 8
  val_batch_size: 8
  return_raw_input_ids: False # Set to True when the policy and the reward model use different tokenizers
  return_raw_chat: False
  shuffle: False
  filter_overlong_prompts: False # For large-scale datasets, filtering overlong prompts can be time-consuming; set filter_overlong_prompts_workers to use multiprocessing and speed it up.
  filter_overlong_prompts_workers: 1

actor_rollout_ref:
  hybrid_engine: True
  exchange_size: 1e9
  model:
    path: Qwen/Qwen2-7B-Instruct
    external_lib: null
    override_config: { }
    enable_gradient_checkpointing: True
    use_remove_padding: True
    trust_remote_code: True
  actor:
    strategy: fsdp # This is for backward-compatibility
    ppo_mini_batch_size: 4
    ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
    ppo_micro_batch_size_per_gpu: 1 # for dynamic bsz
    use_dynamic_bsz: True
    ppo_max_token_len_per_gpu: 32768 # n * ${data.max_prompt_length} + ${data.max_response_length}
    grad_clip: 0.5
    clip_ratio: 0.2
    clip_ratio_low: 0.2
    clip_ratio_high: 0.2
    clip_ratio_c: 3.0
    loss_agg_mode: "token-mean" # "token-mean" / "seq-mean-token-sum" / "seq-mean-token-mean"
    entropy_coeff: 0.0
    use_kl_loss: True # True for GRPO
    kl_loss_coef: 0.0001 # for grpo
    kl_loss_type: low_var_kl # for grpo
    ppo_epochs: 1
    shuffle: False
    ulysses_sequence_parallel_size: 1 # sp size
    multi_turn: True
    checkpoint:
      contents: ['model', 'optimizer', 'extra'] # add 'hf_model' to also save the whole model in HF format; by default only sharded checkpoints are saved to save space
    optim:
      lr: 1e-6
      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
      min_lr_ratio: null # only useful for warmup with cosine
      warmup_style: constant # select from constant/cosine
      total_training_steps: -1 # must be overridden by the program
    fsdp_config:
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      param_offload: True
      optimizer_offload: True
      fsdp_size: -1
  ref:
    fsdp_config:
      param_offload: True
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: null
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
  rollout:
    name: sglang_async
    prompt_length: ${data.max_prompt_length} # not used for open-source
    response_length: ${data.max_response_length}
    # for vllm rollout
    dtype: bfloat16 # should align with FSDP
    temperature: ${.sampling_params.temperature} # this is currently ignored
    gpu_memory_utilization: 0.8
    enable_memory_saver: False
    ignore_eos: False
    enforce_eager: True
    free_cache_engine: True
    load_format: dummy_dtensor
    tensor_model_parallel_size: 4
    log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: null
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    disable_log_stats: True
    enable_chunked_prefill: True # could get higher throughput
    # for hf rollout
    do_sample: True
    # number of responses (i.e. number of sample times)
    n: 2
    multi_turn: ${actor_rollout_ref.actor.multi_turn}
    max_turns: 3
    plugin_browser: False
    path: ${actor_rollout_ref.model.path}
    sampling_params:
      temperature: 0.8
      max_new_tokens: 192
      stop: []
    val_kwargs:
      # sampling parameters for validation
      top_k: -1 # 0 for hf rollout, -1 for vllm rollout
      top_p: 1.0
      temperature: 0
      n: 1
      do_sample: False # greedy decoding by default for validation
    tool_kwargs:
      tool_config_path: "gsm8k_tool_config.yaml"

critic:
  strategy: fsdp
  optim:
    lr: 1e-5
    lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
    min_lr_ratio: null # only useful for warmup with cosine
    warmup_style: constant # select from constant/cosine
    total_training_steps: -1 # must be overridden by the program
  model:
    path: Qwen/Qwen2-7B-Instruct
    tokenizer_path: ${actor_rollout_ref.model.path}
    override_config: { }
    external_lib: ${actor_rollout_ref.model.external_lib}
    enable_gradient_checkpointing: True
    use_remove_padding: False
    fsdp_config:
      param_offload: False
      optimizer_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      fsdp_size: -1
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
  ppo_micro_batch_size_per_gpu: null
  forward_micro_batch_size: ${critic.ppo_micro_batch_size}
  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
  ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
  ulysses_sequence_parallel_size: 1 # sp size
  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
  shuffle: ${actor_rollout_ref.actor.shuffle}
  grad_clip: 1.0
  cliprange_value: 0.5
  checkpoint:
    contents: ['model', 'optimizer', 'extra'] # add 'hf_model' to also save the whole model in HF format; by default only sharded checkpoints are saved to save space

reward_model:
  enable: False

algorithm:
  gamma: 1.0
  lam: 1.0
  adv_estimator: grpo
  use_kl_in_reward: False
  kl_penalty: kl # how to estimate kl divergence
  kl_ctrl:
    type: fixed
    kl_coef: 0.001

trainer:
  hybrid_engine: True
  total_epochs: 3
  total_training_steps: null
  project_name: gsm8k_async_rl
  experiment_name: qwen7b_sft2_16k_t08_n8_v6
  logger: [ 'console', 'wandb' ]
  val_generations_to_log_to_wandb: 0
  nnodes: 1
  n_gpus_per_node: 4
  save_freq: 100
  # auto: resume from the last checkpoint if one exists, otherwise start from scratch
  resume_mode: auto # auto or resume_path
  resume_from_path: False
  test_freq: -1
  critic_warmup: 0
  default_hdfs_dir: null
  remove_previous_ckpt_in_save: False
  del_local_ckpt_after_load: False
  default_local_dir: /workspace/gsm8k/ckpt/${trainer.project_name}/${trainer.experiment_name}
  val_before_train: False
  balance_batch: False
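
As a side note (not part of the PR): since this config leans heavily on ${...} interpolations, a quick OmegaConf round-trip catches broken references before a run is launched. The load path below simply mirrors the file location shown in the diff and is an assumption about the checkout layout.

from omegaconf import OmegaConf

cfg = OmegaConf.load("verl/trainer/config/gsm8k.yaml")

# Force every ${...} reference (e.g. critic.forward_max_token_len_per_gpu,
# actor_rollout_ref.rollout.multi_turn) to resolve; a typo raises immediately.
resolved = OmegaConf.to_container(cfg, resolve=True)

# rollout.multi_turn mirrors actor.multi_turn through interpolation
assert resolved["actor_rollout_ref"]["rollout"]["multi_turn"] is True
print(resolved["actor_rollout_ref"]["rollout"]["tool_kwargs"]["tool_config_path"])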
15 changes: 15 additions & 0 deletions verl/trainer/config/gsm8k_tool_config.yaml
@@ -0,0 +1,15 @@
tools:
  - class_name: "tool.gsm8k_tool.Gsm8kTool"
    config:
      name: "calc_gsm8k_reward"
      description: "A tool for calculating the reward of gsm8k"
      parameters:
        type: "object"
        properties:
          response:
            type: "string"
            description: "The model's response to the GSM8K math problem"
          ground_truth:
            type: "string"
            description: "The ground truth answer to the GSM8K math problem"
        required: ["response", "ground_truth"]
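
For context, a hypothetical sketch (not the PR's actual tool.gsm8k_tool.Gsm8kTool): the loader added in async_sglang_rollout.py only requires that the class accept the resolved config: block above and expose a name plus get_openai_tool_schema(); the real tool presumably also implements the reward-calculation call itself.

class Gsm8kTool:
    def __init__(self, config: dict):
        # `config` is the resolved `config:` block from gsm8k_tool_config.yaml
        self._config = config
        self.name = config["name"]  # key used in the rollout's _tool_map

    def get_openai_tool_schema(self) -> dict:
        # repackage name/description/parameters as an OpenAI-style function tool
        return {
            "type": "function",
            "function": {
                "name": self._config["name"],
                "description": self._config["description"],
                "parameters": self._config["parameters"],
            },
        }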
16 changes: 11 additions & 5 deletions verl/trainer/ppo/ray_trainer.py
@@ -134,14 +134,19 @@ def _check_resource_available(self):
 from verl.utils.torch_functional import masked_mean
 
 
-def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty='kl'):
+def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty='kl', multi_turn=False):
     responses = data.batch['responses']
     response_length = responses.size(1)
     token_level_scores = data.batch['token_level_scores']
     batch_size = data.batch.batch_size[0]
-    attention_mask = data.batch['attention_mask']
-    response_mask = attention_mask[:, -response_length:]
+
+    if multi_turn:
+        loss_mask = data.batch['loss_mask']
+        response_mask = loss_mask[:, -response_length:]
+    else:
+        attention_mask = data.batch['attention_mask']
+        response_mask = attention_mask[:, -response_length:]
 
     # compute kl between ref_policy and current policy
     # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled.
     kld = core_algos.kl_penalty(data.batch['old_log_probs'], data.batch['ref_log_prob'],
@@ -886,7 +891,8 @@ def fit(self):
                 if self.config.algorithm.use_kl_in_reward:
                     batch, kl_metrics = apply_kl_penalty(batch,
                                                          kl_ctrl=self.kl_ctrl_in_reward,
-                                                         kl_penalty=self.config.algorithm.kl_penalty)
+                                                         kl_penalty=self.config.algorithm.kl_penalty,
+                                                         multi_turn=self.config.actor_rollout_ref.actor.get('multi_turn', False))
                     metrics.update(kl_metrics)
                 else:
                     batch.batch['token_level_rewards'] = batch.batch['token_level_scores']
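A toy illustration of the change above (not code from the PR): in a multi-turn rollout the response span also contains tool-output tokens, which the loss_mask presumably zeroes out, so the per-token KL penalty only lands on tokens the policy actually generated. The tensors below are made up for the example.

import torch

response_length = 6
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0]])  # 1 = real token, 0 = padding
loss_mask      = torch.tensor([[1, 1, 0, 0, 1, 0]])  # 0 also for tool-output tokens

kld = torch.full((1, response_length), 0.1)

single_turn_penalty = kld * attention_mask[:, -response_length:]  # old behaviour
multi_turn_penalty  = kld * loss_mask[:, -response_length:]       # multi_turn=True behaviour

print(single_turn_penalty.sum().item(), multi_turn_penalty.sum().item())  # ~0.5 vs ~0.3
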
4 changes: 3 additions & 1 deletion verl/workers/fsdp_workers.py
@@ -378,7 +378,9 @@ def _build_rollout(self, trust_remote_code=False):
                                                            device_mesh=rollout_device_mesh)
             log_gpu_memory_usage('After building sharding manager', logger=None)
 
-
+        else:
+            raise NotImplementedError(f"Rollout name: {self.config.rollout.name} is not supported")
+
         return rollout, rollout_sharding_manager
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
26 changes: 26 additions & 0 deletions verl/workers/rollout/sglang_rollout/async_sglang_rollout.py
@@ -131,6 +131,32 @@ def __init__(
         """
         super().__init__()
         self.config = config
+
+        if config.get("tool_kwargs") and config.tool_kwargs.get("tools_config_file"):
+            import sys
+            import importlib.util
+
+            tool_list = []
+
+            tools_config_path = config.tool_kwargs.get("tools_config_file")
+            tools_config = OmegaConf.load(tools_config_path)
+
+            for tool_config in tools_config.tools:
+                cls_name = tool_config.class_name
+                module_name, class_name = cls_name.rsplit(".", 1)
+
+                if module_name not in sys.modules:
+                    spec = importlib.util.find_spec(module_name)
+                    module = importlib.util.module_from_spec(spec)
+                    sys.modules[module_name] = module
+                    spec.loader.exec_module(module)
+                else:
+                    module = sys.modules[module_name]
+
+                tool_cls = getattr(module, class_name)
+                tool = tool_cls(config=OmegaConf.to_container(tool_config.config, resolve=True))
+                tool_list.append(tool)
+
         if tool_list is not None:
             self._tool_schemas = [tool.get_openai_tool_schema() for tool in tool_list]
             self._tool_map = {tool.name: tool for tool in tool_list}
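A small aside on the dynamic import above: importlib.import_module already consults and populates sys.modules, so a roughly equivalent loader can skip the explicit find_spec/module_from_spec steps, assuming the tool module (e.g. tool.gsm8k_tool) is importable from sys.path.

import importlib

def load_tool(class_path: str, config: dict):
    # class_path is e.g. "tool.gsm8k_tool.Gsm8kTool" from the tool config YAML
    module_name, class_name = class_path.rsplit(".", 1)
    tool_cls = getattr(importlib.import_module(module_name), class_name)
    return tool_cls(config=config)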