trinity/common/config.py (1 addition, 0 deletions)
@@ -646,6 +646,7 @@ class TrainerConfig:
     # if None, automatically set to ceil(2 * model.max_model_len / ulysses_sequence_parallel_size)
     max_token_len_per_gpu: Optional[int] = None
     ulysses_sequence_parallel_size: int = 1  # sp size
+    fix_actor_microbatch_loss_scale: bool = False  # EXPERIMENTAL
     # TODO: extract more train-related params from underlying trainer engine

     save_strategy: SaveStrategy = SaveStrategy.UNRESTRICTED
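For orientation (not part of the diff): the flag defaults to False, so existing configs keep the current behavior. A hypothetical opt-in sketch, assuming TrainerConfig's remaining fields all have defaults, as the hunk suggests; the project's real Config wiring is elided:

# Hypothetical opt-in; only the new field is set explicitly.
from trinity.common.config import TrainerConfig

trainer_cfg = TrainerConfig(fix_actor_microbatch_loss_scale=True)  # default: False
assert trainer_cfg.fix_actor_microbatch_loss_scale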
trinity/common/verl_config.py (5 additions, 0 deletions)
@@ -136,6 +136,7 @@ class Actor:
     ppo_micro_batch_size_per_gpu: int = 1
     use_dynamic_bsz: Optional[bool] = None
     ppo_max_token_len_per_gpu: Optional[int] = None
+    fix_actor_microbatch_loss_scale: Optional[bool] = None  # EXPERIMENTAL
     grad_clip: Optional[float] = None
     ppo_epochs: int = 1
     shuffle: bool = False
@@ -427,6 +428,10 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
             self.actor_rollout_ref.actor.ppo_max_token_len_per_gpu = (
                 config.trainer.max_token_len_per_gpu
             )
+        if self.actor_rollout_ref.actor.fix_actor_microbatch_loss_scale is None:
+            self.actor_rollout_ref.actor.fix_actor_microbatch_loss_scale = (
+                config.trainer.fix_actor_microbatch_loss_scale
+            )
         if self.actor_rollout_ref.actor.ulysses_sequence_parallel_size is None:
             self.actor_rollout_ref.actor.ulysses_sequence_parallel_size = (
                 config.trainer.ulysses_sequence_parallel_size
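Note on the pattern: the verl-level field is declared Optional[bool] = None so that an explicitly set actor-level value wins while an unset one inherits the trainer-level setting, matching the neighboring synchronized fields. A minimal sketch of the resolution rule, with names simplified from the hunk:

# The resolution rule from synchronize_config, in isolation.
def resolve(actor_value, trainer_value):
    # None means "unset at the actor level": inherit the trainer setting.
    return trainer_value if actor_value is None else actor_value

assert resolve(None, True) is True     # unset -> inherited from trainer config
assert resolve(False, True) is False   # explicit actor-level setting wins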
trinity/trainer/verl/dp_actor.py (28 additions, 6 deletions)
@@ -88,6 +88,16 @@ def update_policy(self, data: DataProto):  # noqa: C901

         mini_batches = data.split(self.config.ppo_mini_batch_size)

+        # EXPERIMENTAL: apply loss scale fix
+        loss_agg_mode = (
+            self.policy_loss_fn.loss_agg_mode
+            if hasattr(self.policy_loss_fn, "loss_agg_mode")
+            else "token-mean"
+        )
+        do_fix_actor_microbatch_loss_scale = self.config.fix_actor_microbatch_loss_scale and (
+            loss_agg_mode == "token-mean"
+        )
+
         metrics = {}
         for _ in range(self.config.ppo_epochs):
             for batch_idx, mini_batch in enumerate(mini_batches):
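Why the fix is gated on token-mean aggregation: under loss_agg_mode == "token-mean", each microbatch's policy_loss is a mean over that microbatch's response tokens, so recovering the minibatch-wide token mean requires weighting each microbatch by its share of the tokens. Worked example: one microbatch holds a single response token with loss 4, another holds 3 response tokens with loss 0. The true token mean is 4 / 4 = 1; uniform weighting of the two microbatch means gives (4 + 0) / 2 = 2, while weighting by token fraction gives 4 * (1/4) + 0 * (3/4) = 1. Sequence-level aggregation modes keep the existing scaling, which already weights by sequence counts, and a loss function that exposes no loss_agg_mode attribute is treated as token-mean.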
@@ -104,6 +114,12 @@ def update_policy(self, data: DataProto):  # noqa: C901
                 )
                 micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)

+                if do_fix_actor_microbatch_loss_scale:
+                    # calculate the total number of response tokens in the minibatch
+                    mini_batch_token_num = torch.sum(
+                        mini_batch.batch["response_mask"].to(get_device_id())
+                    ).item()
+
                 self.actor_optimizer.zero_grad()

                 for micro_batch in micro_batches:
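Here mini_batch_token_num is the denominator N of the fixed scale in the next hunk. It is computed once per minibatch because splitting the minibatch into microbatches partitions its response tokens; a toy sketch of that invariant, with names borrowed from the diff:

import torch

response_mask = torch.tensor([[1, 1, 0], [1, 0, 0], [1, 1, 1], [1, 1, 0]])
mini_batch_token_num = int(response_mask.sum())  # 8 response tokens in total
micro_masks = response_mask.split(2)             # microbatches of 2 sequences each
assert sum(int(m.sum()) for m in micro_masks) == mini_batch_token_num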
@@ -156,13 +172,19 @@
                         )
                         policy_loss = policy_loss + kl_loss

-                    if self.config.use_dynamic_bsz:
-                        # relative to the dynamic bsz
-                        loss = policy_loss * (
-                            response_mask.shape[0] / self.config.ppo_mini_batch_size
-                        )
+                    # set loss scale for the microbatch
+                    if not do_fix_actor_microbatch_loss_scale:
+                        # original implementation of microbatch loss scale
+                        if self.config.use_dynamic_bsz:
+                            loss_scale = response_mask.shape[0] / self.config.ppo_mini_batch_size
+                        else:
+                            loss_scale = 1.0 / self.gradient_accumulation
                     else:
-                        loss = policy_loss / self.gradient_accumulation
+                        # EXPERIMENTAL: fix for token-mean loss aggregation
+                        # scale microbatch loss according to the number of tokens (rather than sequences)
+                        loss_scale = torch.sum(response_mask).item() / (mini_batch_token_num + 1e-6)
+
+                    loss = policy_loss * loss_scale
                     loss.backward()

                     append_to_dict(metrics, micro_batch_metrics)
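A self-contained sanity check of the scaling math (toy code, not project code), assuming, as the gating above implies, that policy_loss is the token-mean loss over each microbatch. It shows that weighting each microbatch mean by its token fraction n_i / N reproduces the gradient of the exact minibatch-wide token mean, while the uniform 1 / gradient_accumulation weighting does not once token counts differ:

import torch

torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)  # toy "policy" parameters

def per_token_losses(num_tokens):
    x = torch.randn(num_tokens, 4)
    return (x @ w) ** 2  # one scalar loss per response token

token_counts = [3, 9]  # two microbatches with unequal token counts
losses = [per_token_losses(n) for n in token_counts]
total_tokens = sum(token_counts)  # plays the role of mini_batch_token_num

# Reference: exact token mean over the whole minibatch.
ref = torch.cat(losses).mean()
ref_grad = torch.autograd.grad(ref, w, retain_graph=True)[0]

# Unfixed path: uniform 1 / gradient_accumulation weighting of microbatch means.
old = sum(mb.mean() / len(token_counts) for mb in losses)
old_grad = torch.autograd.grad(old, w, retain_graph=True)[0]

# Fixed path: weight each microbatch mean by its token fraction n_i / N.
new = sum(mb.mean() * (mb.numel() / total_tokens) for mb in losses)
new_grad = torch.autograd.grad(new, w)[0]

print(torch.allclose(new_grad, ref_grad))  # True: matches the exact token mean
print(torch.allclose(old_grad, ref_grad))  # False: short microbatches are overweighted

The 1e-6 in the diff's denominator only guards against an empty response mask; it does not change this equivalence in any realistic case.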