11 changes: 11 additions & 0 deletions docs/source/bema_for_reference_model.md
@@ -29,3 +29,14 @@ trainer = DPOTrainer(

trainer.train()
```

## DPOTrainer

[[autodoc]] experimental.bema_for_ref_model.DPOTrainer
- train
- save_model
- push_to_hub

## BEMACallback

[[autodoc]] experimental.bema_for_ref_model.BEMACallback
11 changes: 11 additions & 0 deletions docs/source/gfpo.md
@@ -37,3 +37,14 @@ trainer = GFPOTrainer(
)
trainer.train()
```

## GFPOTrainer

[[autodoc]] experimental.gfpo.GFPOTrainer
- train
- save_model
- push_to_hub

## GFPOConfig

[[autodoc]] experimental.gfpo.GFPOConfig
15 changes: 15 additions & 0 deletions docs/source/grpo_with_replay_buffer.md
@@ -39,3 +39,18 @@ previous_trainable_params = {n: param.clone() for n, param in trainer.model.name

trainer.train()
```

## GRPOWithReplayBufferTrainer

[[autodoc]] experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer
- train
- save_model
- push_to_hub

## GRPOWithReplayBufferConfig

[[autodoc]] experimental.grpo_with_replay_buffer.GRPOWithReplayBufferConfig

## ReplayBuffer

[[autodoc]] experimental.grpo_with_replay_buffer.ReplayBuffer
7 changes: 7 additions & 0 deletions docs/source/gspo_token.md
@@ -16,3 +16,10 @@ training_args = GRPOConfig(

> [!WARNING]
> To leverage GSPO-token, the user needs to provide a per-token advantage \\( \hat{A_{i,t}} \\) for each token \\( t \\) in sequence \\( i \\), i.e., make \\( \hat{A_{i,t}} \\) vary with \\( t \\). That is not the case here, where \\( \hat{A_{i,t}}=\hat{A_{i}} \\). Otherwise, the GSPO-token gradient is equivalent to that of the original GSPO implementation.
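
The equivalence mentioned in the warning can be sketched as follows. This assumes the GSPO-token definition from the GSPO paper, which is not shown in this diff: \\( s_{i,t}(\theta) = \mathrm{sg}[s_i(\theta)] \cdot \pi_\theta(y_{i,t} \mid x, y_{i,<t}) / \mathrm{sg}[\pi_\theta(y_{i,t} \mid x, y_{i,<t})] \\), where \\( \mathrm{sg} \\) is stop-gradient and \\( s_i(\theta) \\) is the length-normalized sequence-level importance ratio. Clipping is ignored for brevity, so this is an illustration rather than a full derivation.

```latex
% GSPO-token: unclipped gradient term for sequence i
\nabla_\theta \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} s_{i,t}(\theta)\, \hat{A_{i,t}}
  = \frac{1}{|y_i|} \sum_{t=1}^{|y_i|} s_i(\theta)\, \hat{A_{i,t}}\,
    \nabla_\theta \log \pi_\theta(y_{i,t} \mid x, y_{i,<t})

% GSPO (sequence level): unclipped gradient term for sequence i
\nabla_\theta \left( s_i(\theta)\, \hat{A_{i}} \right)
  = s_i(\theta)\, \hat{A_{i}} \cdot \frac{1}{|y_i|} \sum_{t=1}^{|y_i|}
    \nabla_\theta \log \pi_\theta(y_{i,t} \mid x, y_{i,<t})
```

Setting \\( \hat{A_{i,t}}=\hat{A_{i}} \\) for every \\( t \\) makes the two right-hand sides identical, which is why a constant per-sequence advantage gives no benefit over plain GSPO.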

## GRPOTrainer

[[autodoc]] experimental.gspo_token.GRPOTrainer
- train
- save_model
- push_to_hub
2 changes: 1 addition & 1 deletion docs/source/merge_model_callback.md
@@ -1,3 +1,3 @@
# MergeModelCallback

-[[autodoc]] MergeModelCallback
+[[autodoc]] experimental.merge_model_callback.MergeModelCallback
10 changes: 5 additions & 5 deletions trl/experimental/merge_model_callback.py
@@ -268,12 +268,12 @@ def create(self) -> "MergeConfiguration":
return self.create_merge_config_slerp()


-def merge_models(config: MergeConfig, out_path: str):
+def merge_models(config: "MergeConfiguration", out_path: str):

Member

indeed, good catch

"""
Merge two models using mergekit

Args:
-config ([`MergeConfig`]): The merge configuration.
+config (`MergeConfiguration`): The merge configuration.
out_path (`str`): The output path for the merged model.
"""
if not is_mergekit_available():
@@ -297,8 +297,8 @@ class MergeModelCallback(TrainerCallback):
on a merge configuration.

Args:
-merge_config ([`MergeConfig`], *optional*):
-Configuration used for the merging process. If not provided, the default [`MergeConfig`] is used.
+merge_config ([`experimental.merge_model_callback.MergeConfig`], *optional*):
+Configuration used for the merging process. If not provided, the default [`~experimental.merge_model_callback.MergeConfig`] is used.
merge_at_every_checkpoint (`bool`, *optional*, defaults to `False`):
Whether to merge the model at every checkpoint.
push_to_hub (`bool`, *optional*, defaults to `False`):
@@ -307,7 +307,7 @@ class MergeModelCallback(TrainerCallback):
Example:

```python
-from trl.experiemental.merge_model_callback import MergeConfig, MergeModelCallback
+from trl.experimental.merge_model_callback import MergeConfig, MergeModelCallback

config = MergeConfig()
merge_callback = MergeModelCallback(config)