[BREAKING][model, data] feat: add support for Mistral3 #2338
New file (+57 lines):

```shell
set -x

ENGINE=${1:-vllm}
# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:
# export VLLM_ATTENTION_BACKEND=XFORMERS
echo $HOME
train_files=$HOME/data/geo3k/train.parquet
test_files=$HOME/data/geo3k/test.parquet
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=${train_files} \
    data.val_files=${test_files} \
    data.train_batch_size=256 \
    data.max_prompt_length=8196 \
    data.max_response_length=2048 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.image_key=images \
    actor_rollout_ref.model.path=mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.01 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.actor.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap="['MistralDecoderLayer']" \
    actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=5 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.ref.strategy=fsdp2 \
    actor_rollout_ref.actor.strategy=fsdp2 \
    actor_rollout_ref.ref.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap="['MistralDecoderLayer']" \
    actor_rollout_ref.ref.fsdp_config.model_dtype=bfloat16 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.rollout.name=$ENGINE \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.free_cache_engine=False \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=5 \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_grpo_example_geo3k' \
    trainer.experiment_name='mistral3_1_24b_ai' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@
```
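The launch command above configures everything through Hydra-style dotted `key=value` overrides. As a rough illustration of the mechanism only (not verl's or Hydra's actual implementation, which also handles type parsing and validation), a dotted override can be applied to a nested config dict like this:

```python
def apply_override(cfg: dict, dotted_key: str, value) -> None:
    """Set a value in a nested dict via a dotted path, e.g. 'data.train_batch_size'."""
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for part in parents:
        # Create intermediate dicts as needed while walking the path.
        node = node.setdefault(part, {})
    node[leaf] = value


cfg = {"data": {"train_batch_size": 1024}, "trainer": {}}
apply_override(cfg, "data.train_batch_size", 256)
apply_override(cfg, "trainer.n_gpus_per_node", 8)
```

This is why every flag in the script, such as `data.train_batch_size=256`, maps cleanly onto one leaf of the trainer's nested configuration.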
Collaborator: Thanks for the contribution. Currently, when a new model is added, the adaptors are scattered around and many files are touched. I am thinking of reorganizing the files that need to be changed so that we can have one folder per model. For instance: And in each model folder, the structure is like below (take mistral3 as the example): What do you think? cc @hiyouga @Fazziekey BTW, I am also not sure if we want to have @ISEEKYAN what do you think? Similarly, tests can be standardized: With a better code structure it will be easier to write new-model onboarding documentation and let the community add new SOTA models.

Collaborator: I think a unified model unit test like For megatron, LLMs of different archs share the same

New file (+155 lines):

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from typing import List

import numpy as np
import torch

COLLATE_FN_MANAGER_REGISTRY = {}


def _pad_for_batching(
    pixel_values: List[torch.Tensor],
    image_sizes: List[List[int]],
):
    """
    Pads images on the `num_of_patches` dimension with zeros to form a batch with the same number of patches.

    Args:
        pixel_values (`List[torch.Tensor]`):
            An array of pixel values for each image of shape (`batch_size`, `channels`, `height`, `width`).
        image_sizes (`List[List[int]]`):
            A list of sizes for each image in `pixel_values`, in (height, width) format.

    Returns:
        List[`torch.Tensor`]: The padded images.
    """
    max_shape = (max(size[0] for size in image_sizes), max(size[1] for size in image_sizes))
    pixel_values = [
        torch.nn.functional.pad(image, pad=(0, max_shape[1] - size[1], 0, max_shape[0] - size[0])).unsqueeze(0)
        for image, size in zip(pixel_values, image_sizes)
    ]
    return pixel_values
```
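`_pad_for_batching` zero-pads each image's patch grid up to the batch-wide maximum height and width so the results can later be stacked. A dependency-free sketch of the same idea on 2-D nested lists (illustrative only, not the torch implementation above):

```python
def pad_to_max(grids):
    """Zero-pad 2-D nested lists up to the largest (rows, cols) in the batch."""
    max_r = max(len(g) for g in grids)
    max_c = max(len(row) for g in grids for row in g)
    padded = []
    for g in grids:
        # Pad each row on the right, then append all-zero rows at the bottom.
        rows = [row + [0] * (max_c - len(row)) for row in g]
        rows += [[0] * max_c for _ in range(max_r - len(rows))]
        padded.append(rows)
    return padded


# A 1x2 grid and a 2x1 grid both become 2x2.
batch = pad_to_max([[[1, 2]], [[3], [4]]])
```

The torch version does the same thing with `torch.nn.functional.pad`, whose `pad` argument lists (left, right, top, bottom) amounts for the last two dimensions.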
```python
def register_collate_fn(name):
    """Decorator to register a collate function with a given name.

    Args:
        name: `(str)`
            The name of the collate function.
    """

    def decorator(cls):
        if name in COLLATE_FN_MANAGER_REGISTRY and COLLATE_FN_MANAGER_REGISTRY[name] != cls:
            raise ValueError(
                f"Collate function manager {name} has already been registered: "
                f"{COLLATE_FN_MANAGER_REGISTRY[name]} vs {cls}"
            )
        COLLATE_FN_MANAGER_REGISTRY[name] = cls
        return cls

    return decorator


def get_collate_fn_manager_cls(name):
    """Get the collate function manager class with a given name.

    Args:
        name: `(str)`
            The name of the collate function manager.

    Returns:
        `(type)`: The collate function manager class.
    """
    if name not in COLLATE_FN_MANAGER_REGISTRY:
        default_collate_fn = COLLATE_FN_MANAGER_REGISTRY.get("default", None)
        if default_collate_fn is None:
            raise ValueError(f"Unknown collate function manager: {name}")
        return default_collate_fn

    return COLLATE_FN_MANAGER_REGISTRY[name]
```
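The pair of helpers above implements a lookup-with-fallback registry: a processor class name such as `'PixtralProcessor'` selects a matching collate function, and unknown names fall back to the `'default'` entry. A self-contained miniature of the same pattern:

```python
REGISTRY = {}


def register(name):
    """Decorator that stores the decorated function under `name`."""
    def decorator(fn):
        REGISTRY[name] = fn
        return fn
    return decorator


def get(name):
    # Unknown names fall back to the "default" entry (the real helper
    # raises only when no default is registered either).
    return REGISTRY.get(name, REGISTRY["default"])


@register("default")
def default_fn():
    return "default"


@register("PixtralProcessor")
def pixtral_fn():
    return "pixtral"
```

This lets the dataset code pick a model-specific collate function from the processor's class name without hard-coding per-model branches.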
```python
@register_collate_fn("default")
def collate_fn(data_list: List[dict]) -> dict:
    """
    Collate a batch of sample dicts into batched tensors and arrays.

    Args:
        data_list: List of dicts mapping feature names to torch.Tensor or other values.

    Returns:
        Dict where tensor entries are stacked into a torch.Tensor of shape
        (batch_size, *dims) and non-tensor entries are converted to
        np.ndarray of dtype object with shape (batch_size,).
    """
    tensors = defaultdict(list)
    non_tensors = defaultdict(list)

    for data in data_list:
        for key, val in data.items():
            if isinstance(val, torch.Tensor):
                tensors[key].append(val)
            else:
                non_tensors[key].append(val)

    for key, val in tensors.items():
        tensors[key] = torch.stack(val, dim=0)

    for key, val in non_tensors.items():
        non_tensors[key] = np.array(val, dtype=object)

    return {**tensors, **non_tensors}
```
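The default `collate_fn` splits each sample's fields into stackable tensors and pass-through metadata. A simplified, dependency-free sketch of that split, with plain lists standing in for tensors (illustrative only):

```python
from collections import defaultdict


def collate(samples):
    """Group per-sample fields into per-batch lists, tensor-like vs metadata."""
    stacked = defaultdict(list)
    meta = defaultdict(list)
    for sample in samples:
        for key, val in sample.items():
            # Lists play the role of torch.Tensor here; everything else is metadata.
            (stacked if isinstance(val, list) else meta)[key].append(val)
    return {**stacked, **meta}


batch = collate([
    {"input_ids": [1, 2], "prompt": "a"},
    {"input_ids": [3, 4], "prompt": "b"},
])
```

In the real function, the tensor-like group is additionally stacked with `torch.stack` and the metadata group becomes an object-dtype NumPy array.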
```python
@register_collate_fn("PixtralProcessor")
def collate_fn_for_pixtral(data_list: List[dict]) -> dict:
    """
    Collate a batch of sample dicts into batched tensors and arrays,
    padding Pixtral pixel values to a common shape first.

    Args:
        data_list: List of dicts mapping feature names to torch.Tensor or other values.

    Returns:
        Dict where tensor entries are stacked into a torch.Tensor of shape
        (batch_size, *dims) and non-tensor entries are converted to
        np.ndarray of dtype object with shape (batch_size,).
    """
    tensors = defaultdict(list)
    non_tensors = defaultdict(list)

    for data in data_list:
        for key, val in data.items():
            if isinstance(val, torch.Tensor):
                tensors[key].append(val)
            else:
                non_tensors[key].append(val)

    for key, val in tensors.items():
        tensors[key] = torch.stack(val, dim=0)

    for key, val in non_tensors.items():
        if key == "multi_modal_inputs":
            val = [ele for ele in val if ele]
            if not val:
                continue
            pixel_values = [v["pixel_values"][0] for v in val]
            image_sizes = [v["image_sizes"][0] for v in val]
            pixel_values = _pad_for_batching(pixel_values, image_sizes)
            for v, pixel_value in zip(val, pixel_values):
                v["pixel_values"] = pixel_value
        non_tensors[key] = np.array(val, dtype=object)

    return {**tensors, **non_tensors}
```
Reviewer: This configuration seems not to take effect in the code?

Author: Actually, this is supposed to be overridden by the parameter passed in through the script. Look at this line:

```shell
actor_rollout_ref.actor.fsdp_config.wrap_policy.transformer_layer_cls_to_wrap="['MistralDecoderLayer']" \
```

Reviewer: Oh, I see, it's here: https://github.com/volcengine/verl/blob/main/verl/utils/fsdp_utils.py#L90-L92

Author: That's right!
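For context on the exchange above: the FSDP wrap-policy config lists layer class names as strings, which the trainer resolves to actual module classes before building the auto-wrap policy. A hypothetical, simplified sketch of that resolution step (names and structure are illustrative, not verl's exact `fsdp_utils` code):

```python
def build_wrap_predicate(layer_cls_names, name_to_class):
    """Return a predicate that is True for modules whose class name was configured for wrapping."""
    targets = tuple(name_to_class[name] for name in layer_cls_names)

    def should_wrap(module) -> bool:
        return isinstance(module, targets)

    return should_wrap


# Stand-ins for the real transformers module classes.
class MistralDecoderLayer:
    pass


class Embedding:
    pass


pred = build_wrap_predicate(
    ["MistralDecoderLayer"],
    {"MistralDecoderLayer": MistralDecoderLayer, "Embedding": Embedding},
)
```

With a predicate like this, FSDP can shard each `MistralDecoderLayer` as its own unit while leaving other modules to the parent wrapper.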