diff --git a/skyrl/backends/skyrl_train/utils/ppo_utils.py b/skyrl/backends/skyrl_train/utils/ppo_utils.py
index da08288a78..189fecccc0 100644
--- a/skyrl/backends/skyrl_train/utils/ppo_utils.py
+++ b/skyrl/backends/skyrl_train/utils/ppo_utils.py
@@ -86,7 +86,6 @@ def get_kl_controller(algorithm_cfg: AlgorithmConfig):
         raise ValueError(f"Invalid KL controller type: {algorithm_cfg.kl_ctrl.type}")
 
 
-@torch.no_grad()
 def compute_approx_kl(
     log_probs: torch.Tensor,
     log_probs_base: torch.Tensor,
diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py
index 5f56cc3229..91e38d3263 100644
--- a/skyrl/train/trainer.py
+++ b/skyrl/train/trainer.py
@@ -1011,12 +1011,13 @@ def apply_reward_kl_penalty(
         action_log_probs: torch.Tensor = data["action_log_probs"]
 
         # single batched computation
-        kl: Float[torch.Tensor, "batch_size seqlen"] = compute_approx_kl(  # type: ignore
-            action_log_probs,
-            base_action_log_probs,
-            loss_mask=loss_masks_all,
-            kl_estimator_type=self.cfg.trainer.algorithm.kl_estimator_type,
-        )
+        with torch.no_grad():
+            kl: Float[torch.Tensor, "batch_size seqlen"] = compute_approx_kl(  # type: ignore
+                action_log_probs,
+                base_action_log_probs,
+                loss_mask=loss_masks_all,
+                kl_estimator_type=self.cfg.trainer.algorithm.kl_estimator_type,
+            )
 
         kl_max: Float[torch.Tensor, "batch_size"] = torch.max(kl.abs(), dim=-1)[0]  # noqa: F821
         kl_mean: Float[torch.Tensor, "batch_size"] = masked_mean(kl, loss_masks_all, dim=-1)  # noqa: F821
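
Rationale (inferred from the diff, not stated in it): a `@torch.no_grad()` decorator baked into `compute_approx_kl` disables gradient tracking for every caller, so the function could never feed a differentiable KL term; moving the guard to the reward-penalty call site keeps that path gradient-free while leaving the function itself gradient-transparent. Below is a minimal sketch of the pattern, assuming a toy `compute_kl_sketch` helper (hypothetical, not SkyRL's actual signature or estimator):

```python
import torch


def compute_kl_sketch(log_probs: torch.Tensor, log_probs_base: torch.Tensor) -> torch.Tensor:
    # Toy k1-style estimator (log p - log q); deliberately NOT decorated
    # with @torch.no_grad(), so each caller decides whether gradients flow.
    return log_probs - log_probs_base


log_probs = torch.randn(2, 4, requires_grad=True)
log_probs_base = torch.randn(2, 4)

# Differentiable path: a KL term folded into the training loss.
kl_loss = compute_kl_sketch(log_probs, log_probs_base).mean()
kl_loss.backward()                      # gradients reach log_probs
assert log_probs.grad is not None       # would fail if the decorator were still in place

# Non-differentiable path: reward shaping / metrics, mirroring the
# `with torch.no_grad():` wrapper this diff adds in apply_reward_kl_penalty.
with torch.no_grad():
    kl = compute_kl_sketch(log_probs, log_probs_base)
assert not kl.requires_grad             # no autograd graph built under no_grad
```

The tradeoff, as I read it: a decorator freezes the no-grad behavior for all callers, while the context manager makes the non-differentiable intent explicit at each call site that needs it.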