diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 666f26be0d0..6b279fc8903 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -196,6 +196,9 @@ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Te Returns: DataProto: torch.Tensor: the log_prob tensor """ + prev_modes = [m.training for m in self.actor_module] + for module in self.actor_module: + module.eval() use_dynamic_bsz = data.meta_info.get("use_dynamic_bsz", False) micro_batch_size = data.meta_info.get("micro_batch_size", None) max_token_len = data.meta_info.get("max_token_len", None) @@ -306,6 +309,8 @@ def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None): # add empty cache after each compute get_torch_device().empty_cache() + for module, mode in zip(self.actor_module, prev_modes, strict=False): + module.train(mode) return log_probs, entropys, layers_topk_idx def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: diff --git a/verl/workers/critic/megatron_critic.py b/verl/workers/critic/megatron_critic.py index 781f92b0cb4..ecc166cd495 100644 --- a/verl/workers/critic/megatron_critic.py +++ b/verl/workers/critic/megatron_critic.py @@ -87,6 +87,9 @@ def _validate_config(self, config) -> None: @GPUMemoryLogger("megatron critic", logger=logger) def compute_values(self, data: DataProto) -> DataProto: + prev_modes = [m.training for m in self.critic_module] + for module in self.critic_module: + module.eval() responses = data.batch["responses"] attention_mask = data.batch["attention_mask"] use_dynamic_bsz = data.meta_info.get("use_dynamic_bsz", False) @@ -139,6 +142,8 @@ def compute_values(self, data: DataProto) -> DataProto: # add empty cache after each compute get_torch_device().empty_cache() + for module, mode in zip(self.critic_module, prev_modes, strict=False): + module.train(mode) return values def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: