GRPO: AssertionError: backward pass is invalid for module in evaluation mode #3107

Open
JoyTim-777 opened this issue Feb 14, 2025 · 3 comments

Comments

@JoyTim-777

When using the non-vLLM + PEFT setup, I get this error: AssertionError: backward pass is invalid for module in evaluation mode

@Jintao-Huang
Collaborator

Do you have the full error message? I'd like to see where it is raised.

@JoyTim-777
Author

Error message:
File "swift/cli/rlhf.py", line 5, in <module>
rlhf_main()
File "swift/llm/train/rlhf.py", line 96, in rlhf_main
return SwiftRLHF(args).main()
File "swift/llm/base.py", line 46, in main
result = self.run()
File "swift/llm/train/sft.py", line 137, in run
return self.train(trainer)
File "swift/llm/train/sft.py", line 196, in train
trainer.train(trainer.args.resume_from_checkpoint)
File "swift/trainers/mixin.py", line 262, in train
res = super().train(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2171, in train
return inner_training_loop(
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2531, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 3712, in training_step
self.accelerator.backward(loss, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 2240, in backward
self.deepspeed_engine_wrapped.backward(loss, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/deepspeed.py", line 246, in backward
self.engine.backward(loss, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2020, in backward
self.optimizer.backward(loss, retain_graph=retain_graph)
File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/stage3.py", line 2247, in backward
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
scaled_loss.backward(retain_graph=retain_graph)
File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 581, in backward
torch.autograd.backward(
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 347, in backward
_engine_run_backward(
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 307, in apply
return user_fn(self, *args)
File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/parameter_offload.py", line 355, in backward
ctx.pre_backward_function(ctx.module)
File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/parameter_offload.py", line 336, in _run_before_backward_function
self.pre_sub_module_backward_function(sub_module)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/parameter_offload.py", line 466, in pre_sub_module_backward_function
assert sub_module.training, "backward pass is invalid for module in evaluation mode"
AssertionError: backward pass is invalid for module in evaluation mode
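
The last frame of the trace shows DeepSpeed's ZeRO-3 pre-backward hook asserting `sub_module.training`, so at least one submodule was in evaluation mode when backward ran. A minimal diagnostic sketch, assuming the model exposes PyTorch's `named_modules()` and `training` attributes; the helper name `find_eval_mode_submodules` is hypothetical and is not part of swift or DeepSpeed:

```python
from typing import List


def find_eval_mode_submodules(model) -> List[str]:
    """Return the names of all submodules whose `training` flag is False.

    `model` is expected to behave like a torch.nn.Module: `named_modules()`
    yields (name, module) pairs, and the root module has an empty name.
    """
    return [
        name or "<root>"  # label the root module, whose name is ""
        for name, module in model.named_modules()
        if not module.training
    ]


# Hypothetical usage right before the backward call in a training step:
#
#     offenders = find_eval_mode_submodules(model)
#     if offenders:
#         print("modules in eval mode:", offenders[:10])
#         model.train()  # nn.Module.train() sets the flag recursively
```

If the list is non-empty, it may help narrow down which code path left the model in eval mode (for example, a generation step that calls `model.eval()` without restoring training mode afterwards). This is a debugging aid under the stated assumptions, not a confirmed fix for this issue.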

@bonre

bonre commented Feb 14, 2025

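Same problem here. It looks like pt conflicts with ZeRO-3? ZeRO-2 works fine.
I'm using the swift version that includes this afternoon's fix for the ZeRO-3 issue.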
