Fix LightningModule step methods bypassing DDP wrapper in Fabric #17424
Conversation
```python
warning_cache.warn(
    f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
    " model. This will bypass the wrapper from the strategy and result in incorrect behavior in"
    f" `.backward()`. You should pass your inputs through `{type(self._original_module)}.forward()`.",
```
Was `__name__` missed here? If so, can you include the change in #17531?
```python
def _validate_method_access(self, name: str, attribute: Any) -> None:
    if inspect.ismethod(attribute) and self._forward_module != self._original_module:
        warning_cache.warn(
```
I noticed we get some false positives from this in CI:

```
/__w/2/s/src/lightning_fabric/wrappers.py:158: PossibleUserWarning: You are calling the method `ConvNet.get_dataloader()` from outside the model. This will bypass the wrapper from the strategy and result in incorrect behavior in `.backward()`. You should pass your inputs through `<class 'tests_fabric.parity.models.ConvNet'>.forward()`.
  warning_cache.warn(
/__w/2/s/src/lightning_fabric/wrappers.py:158: PossibleUserWarning: You are calling the method `ConvNet.get_loss_function()` from outside the model. This will bypass the wrapper from the strategy and result in incorrect behavior in `.backward()`. You should pass your inputs through `<class 'tests_fabric.parity.models.ConvNet'>.forward()`.
  warning_cache.warn(
```
The pattern used in the test seems natural. I'm wondering whether we should relax this or reword it to make it clearer that this can be okay if you aren't running a computation inside the method.
Yes, I was thinking the same. Sometimes it is ok, sometimes not.
Additionally, to reduce the message spam if lots of different function calls happen, we could show the warning only for the first use but not for others.
That might make it useless, as often the first call will be for something that's not a problem, like `get_dataloader` above.

I think the simplest improvement would be to clarify the warning wording so that it's easier to understand that this might be a false positive. "You should pass your inputs..." makes it sound like you always need to do this.
Hello, I am trying to understand how exactly to set up Fabric to train properly with DDP. I am writing code for distributed reinforcement learning, and as such the modules contain arbitrary submodules. In general, there are two types of calls I make to the model: 1) interacting with the environment, which does not require gradients, and 2) the optimization of the modules, which does require gradients. For case 1) I assume that everything works as intended, even though I had to do some manual device placement after upgrading from fabric 2.0.0 to 2.0.2. Before that, the same code worked without these changes.

My main question is about case 2). As far as I understand, some calls are now redirected to a fake `forward`. The source of my confusion is that in the models I am implementing, there is no natural, single `forward` method. I have a training loop which, after some interactions, calls this:

```python
def optimize(self, optimizers) -> TensorDict:
    self.step += 1
    self.alpha = self.alpha.to(self.device)
    value_logs = self.optimize_value(optimizers["value"])
    policy_logs = self.optimize_policy(optimizers["policy"])
    return self.as_tensordict(**policy_logs, **value_logs)
```

Because I am not exactly sure which pattern is correct, I started redirecting calls to the submodules via the code below. The downside is that this makes everything more ugly, verbose and error-prone, so I would like to avoid having this.

```python
def forward(self, inputs, module) -> TensorDict:
    if module == "policy":
        return self.action(inputs)
    if module == "q":
        return self.as_tensordict(
            q1=self.q1(inputs),
            q2=self.q2(inputs),
        )
    if module == "q-target":
        return self.as_tensordict(
            q1=self.q1_target(inputs),
            q2=self.q2_target(inputs),
        )
    raise ValueError(f"unknown module {module}")
```

So my question is whether this is even necessary (in the top-level module), or whether it is sufficient to simply call the wrapped submodules directly.
Yes.
By step methods we mean the ones from LightningModule, namely: training_step, validation_step, test_step, predict_step. If a method is named like this, it will be treated specially and redirected through "forward".
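For example, a minimal sketch of that intended usage (the toy module, data, and hyperparameters below are made up for illustration, not from this PR):

```python
import torch
import lightning as L


class TinyModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)


fabric = L.Fabric(accelerator="cpu", devices=1)
fabric.launch()

model = TinyModule()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

batch = (torch.randn(8, 4), torch.randn(8, 1))
# Because the method is named `training_step`, the call is redirected through the
# wrapped forward, so a DDP strategy would still synchronize gradients correctly.
loss = model.training_step(batch, batch_idx=0)
fabric.backward(loss)
optimizer.step()
optimizer.zero_grad()
```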
No, it does not, and likely never will, because 1) it adds some overhead and 2) we should stay as close as possible to PyTorch's behavior to avoid potential confusion or unexpected behavior.
Yes, it is necessary. And it is not sufficient to only set up/wrap the top-level module while calling the submodules from the outside, as this would bypass the DDP forward (this PR introduced a warning message if you do this). If you don't like this, an alternative would be to set up each submodule individually and get rid of the top-level module. You could inline the logic directly in the Fabric code.

```python
module0 = fabric.setup(module0, ...)
module1 = fabric.setup(module1, ...)
module2 = fabric.setup(module2, ...)
# etc.
```

This will prepare each module with DDP and then you can call their forward individually, and DDP will work correctly this way. Whether this or the other way around is more convenient depends on the use case. Let me know what you think about this.
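For instance, a minimal sketch of that per-submodule pattern (the modules and losses here are made-up placeholders, and `fabric` is assumed to be an already-launched `Fabric` instance with a DDP strategy):

```python
import torch

policy = torch.nn.Linear(8, 2)
value = torch.nn.Linear(8, 1)

policy, policy_opt = fabric.setup(policy, torch.optim.Adam(policy.parameters()))
value, value_opt = fabric.setup(value, torch.optim.Adam(value.parameters()))

obs = torch.randn(16, 8)

# Calling the wrapped module directly goes through the DDP forward,
# so gradient synchronization works for each submodule independently.
policy_loss = policy(obs).pow(2).mean()
fabric.backward(policy_loss)
policy_opt.step()
policy_opt.zero_grad()
```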
Thanks for clarifying this. I will reconsider my current usage of Fabric. I ended up thinking about it some more.

Let's say I have the following loop:

```python
while current_steps_per_epoch < self.env_steps_per_epoch:
    trajectory = self.rollout(train_environment, agent, episode, reward_modules, is_warmup)
    agent.buffer.put(trajectory)
    for _ in range(len(trajectory)):
        agent.training_step(optimizers)
        current_steps_per_epoch += 1
        if current_steps_per_epoch == self.env_steps_per_epoch:
            break
```

Basically, `agent.training_step` is where the optimization happens, and what I am doing there is more or less analogous to one of the two options below.

TL;DR: In general, we should route everything that produces gradients through the wrapped module's `forward`.

Option 1

Option 1 is to build the agent as a single top-level module and set it up as a whole:

```python
moduleA = make_model()
moduleB = make_model()
optimizerA = Optimizer(moduleA.parameters())
optimizerB = Optimizer(moduleB.parameters())
agent = Agent(moduleA, moduleB)
agent, optimizerA, optimizerB = fabric.setup(agent, optimizerA, optimizerB)


def training_step(self, optimizers):
    batch = self.buffer.get(batch_size=self.batch_size)
    opt_A, opt_B = optimizers["A"], optimizers["B"]
    loss_A, loss_B = self.forward(batch)
    opt_A.zero_grad()
    self.fabric.backward(loss_A)
    opt_A.step()
    opt_B.zero_grad()
    self.fabric.backward(loss_B)
    opt_B.step()
```

Are there any implications with using one optimizer with multiple parameter groups? In some cases, it's convenient to write it like this instead:

```python
optimizer = Optimizer([
    dict(params=moduleA.parameters(), lr=1e-3),
    dict(params=moduleB.parameters(), lr=1e-4),
])
agent, optimizer = fabric.setup(agent, optimizer)


def training_step(self, optimizer):
    batch = self.buffer.get(batch_size=self.batch_size)
    loss_A, loss_B = self.forward(batch)
    optimizer.zero_grad()
    self.fabric.backward(loss_A)
    optimizer.step()
    optimizer.zero_grad()
    self.fabric.backward(loss_B)
    optimizer.step()
```

Option 2

Option 2 is to build and set up each module individually and compose them afterwards:

```python
moduleA = make_model()
moduleB = make_model()
optimizerA = Optimizer(moduleA.parameters())
optimizerB = Optimizer(moduleB.parameters())
moduleA, optimizerA = fabric.setup(moduleA, optimizerA)
moduleB, optimizerB = fabric.setup(moduleB, optimizerB)
agent = Agent(moduleA, moduleB)


def training_step(self, optimizers):
    batch = self.buffer.get(batch_size=self.batch_size)
    opt_A, opt_B = optimizers["A"], optimizers["B"]
    loss_A = self.moduleA.forward(batch)
    loss_B = self.moduleB.forward(batch)
    opt_A.zero_grad()
    self.fabric.backward(loss_A)
    opt_A.step()
    opt_B.zero_grad()
    self.fabric.backward(loss_B)
    opt_B.step()
```

Pros and Cons

Option 1 seems more natural for regular, non-distributed cases, I would say; I always thought of the agent as one model. Going through the above process, I find option 2 much more appealing in my case. If you need features of the combined top-level module, option 1 might still be preferable.

Questions

If I have pairs of identical modules (like target networks in RL), and I only copy weights from the online modules into them, do the target modules also need to go through `fabric.setup`, or can they stay as plain, non-wrapped modules?

So thanks again and please let me know if I missed the point somewhere.
Hi, I wonder whether it is correct (or safe) to call a LightningModule's model hooks in Fabric?
Computation would be fine; it's just that if you return a tensor and want to call backward on it, DDP won't reduce the gradients.
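To illustrate with a toy example (the `compute_loss` helper and the setup below are made up for illustration, assuming `fabric` was created with a DDP strategy and already launched):

```python
import torch


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)

    def compute_loss(self, x, y):
        # Plain computation in a custom method is fine on its own.
        return torch.nn.functional.mse_loss(self.layer(x), y)


model = fabric.setup(Net())
x, y = torch.randn(8, 4), torch.randn(8, 1)

# This call bypasses the DDP forward: backward() still runs,
# but the gradients are not all-reduced across processes.
loss = model.compute_loss(x, y)
fabric.backward(loss)
```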
What does this PR do?
Fixes #16856
Context:
In PyTorch, if you wrap your module into e.g. `DistributedDataParallel`, everything that you want to pass through that model needs to go through `ddp_model.forward`! All of the following kinds of calls would silently work but are incorrect, as they render DDP completely useless:
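For illustration, a minimal sketch of such calls, assuming an already-initialized process group and a `MyModel` that defines a `training_step` (all names here are placeholders):

```python
import torch
from torch.nn.parallel import DistributedDataParallel

model = MyModel()  # placeholder: an nn.Module that also defines training_step
ddp_model = DistributedDataParallel(model)

# Correct: the call goes through the DDP forward, so gradients get synchronized.
output = ddp_model(batch)

# Incorrect: these run without any error, but they skip the DDP forward,
# so gradient synchronization never happens.
output = model(batch)
output = ddp_model.module(batch)
loss = model.training_step(batch, batch_idx)
```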
However, especially the usage in the last statement is what we want the user experience to be when working with a LightningModule in a Fabric training loop. Hence, we need to somehow simulate that `model.training_step(batch, batch_idx)` behaves like `ddp_model(batch, batch_idx)`, without the user having to change anything in their code.

This PR does exactly that using black magic 300 IQ big brain logic. The code for this solution is a bit complicated to understand, but it is concise and confined to a single place in the _FabricModule. Alternatives we considered would be brittle and spread all over the strategies.
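A rough sketch of the idea, not the actual `_FabricModule` code: the wrapper can temporarily point the original module's `forward` at the requested step method and then call the DDP-wrapped module, so the computation still runs inside DDP's forward (the class and names below are invented for illustration):

```python
import torch

_STEP_METHODS = {"training_step", "validation_step", "test_step", "predict_step"}


class _WrapperSketch(torch.nn.Module):
    """Illustrative only: redirect step-method calls through the wrapped forward."""

    def __init__(self, forward_module: torch.nn.Module, original_module: torch.nn.Module):
        super().__init__()
        self._forward_module = forward_module    # e.g. the DDP-wrapped module
        self._original_module = original_module  # the user's LightningModule

    def forward(self, *args, **kwargs):
        return self._forward_module(*args, **kwargs)

    def __getattr__(self, name):
        if name in _STEP_METHODS:
            def call_through_wrapper(*args, **kwargs):
                original_forward = self._original_module.forward
                # Temporarily swap forward so that calling the DDP wrapper
                # executes the requested step method inside DDP's forward pass.
                self._original_module.forward = getattr(self._original_module, name)
                try:
                    return self._forward_module(*args, **kwargs)
                finally:
                    self._original_module.forward = original_forward

            return call_through_wrapper
        return super().__getattr__(name)
```

The real `_FabricModule` handles more edge cases (such as the warning added in this PR); this sketch only conveys the shape of the trick.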
cc @Borda @carmocca @justusschock @awaelchli