
Fix LightningModule step methods bypassing DDP wrapper in Fabric #17424

Merged · 26 commits into master · Apr 21, 2023

Conversation

@awaelchli (Contributor) commented Apr 20, 2023

What does this PR do?

Fixes #16856

Context:
In PyTorch, if you wrap your module into e.g. DistributedDataParallel

model = ...
ddp_model = torch.nn.parallel.DistributedDataParallel(model, ...)

everything that you want to pass through that model needs to go through ddp_model.forward(), like so:

out = ddp_model(batch)

All of the following would silently run but are incorrect, as they would render DDP completely useless:

out = model(batch)
out = ddp_model.module(batch)
out = model.training_step(batch, batch_idx)

However, the last call in particular is exactly the user experience we want when working with a LightningModule in a Fabric training loop. Hence, we need to make model.training_step(batch, batch_idx) behave like ddp_model(batch, batch_idx) without the user having to change anything in their code.

This PR does exactly that using black magic 300 IQ big brain logic. The code for this solution is a bit complicated to understand, but it is concise and confined to a single place in the _FabricModule. The alternatives we considered would be brittle and spread all over the strategies.
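For the curious, here is a minimal sketch of the idea, assuming the redirect is done by temporarily monkeypatching the original module's forward. This is NOT the actual Fabric implementation; the class name, the _SPECIAL_METHODS tuple, and the patching details are illustrative:

from functools import wraps

import torch.nn as nn


class _FabricModuleSketch(nn.Module):
    # Hypothetical stand-in for Fabric's _FabricModule; names are illustrative.
    _SPECIAL_METHODS = ("training_step", "validation_step", "test_step", "predict_step")

    def __init__(self, forward_module: nn.Module, original_module: nn.Module) -> None:
        super().__init__()
        self._forward_module = forward_module    # e.g. the DDP wrapper
        self._original_module = original_module  # the user's LightningModule

    def forward(self, *args, **kwargs):
        # Regular calls go through the strategy wrapper as usual.
        return self._forward_module(*args, **kwargs)

    def __getattr__(self, name):
        if name in _FabricModuleSketch._SPECIAL_METHODS:
            original = getattr(self._original_module, name)

            @wraps(original)
            def call_through_wrapper(*args, **kwargs):
                # Temporarily let the step method act as `forward`, so that
                # calling the wrapper runs DDP's gradient-sync hooks around it.
                self._original_module.forward = original
                try:
                    return self._forward_module(*args, **kwargs)
                finally:
                    # Drop the instance-level override to restore the
                    # class-level `forward`.
                    del self._original_module.forward

            return call_through_wrapper
        return super().__getattr__(name)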

cc @Borda @carmocca @justusschock @awaelchli

@github-actions bot added the fabric (lightning.fabric.Fabric) and pl (Generic label for PyTorch Lightning package) labels on Apr 20, 2023
@awaelchli added the bug (Something isn't working) label on Apr 20, 2023
@awaelchli added this to the 2.0.x milestone on Apr 20, 2023
@github-actions bot removed the pl (Generic label for PyTorch Lightning package) label on Apr 20, 2023
@awaelchli changed the title from "WIP: Fix LightningModule step methods bypassing DDP wrapper in Fabric" to "Fix LightningModule step methods bypassing DDP wrapper in Fabric" on Apr 21, 2023
@awaelchli awaelchli marked this pull request as ready for review April 21, 2023 10:37
@github-actions bot commented Apr 21, 2023

⚡ Required checks status: All passing 🟢

Groups summary

🟢 pytorch_lightning: Tests workflow
Check ID Status
pl-cpu (macOS-11, lightning, 3.8, 1.11) success
pl-cpu (macOS-11, lightning, 3.9, 1.12) success
pl-cpu (macOS-11, lightning, 3.10, 1.13) success
pl-cpu (macOS-11, lightning, 3.10, 2.0) success
pl-cpu (macOS-11, lightning, 3.8, 1.11, oldest) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 1.11) success
pl-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 1.11, oldest) success
pl-cpu (windows-2022, lightning, 3.8, 1.11) success
pl-cpu (windows-2022, lightning, 3.9, 1.12) success
pl-cpu (windows-2022, lightning, 3.10, 1.13) success
pl-cpu (windows-2022, lightning, 3.10, 2.0) success
pl-cpu (windows-2022, lightning, 3.8, 1.11, oldest) success
pl-cpu (macOS-11, pytorch, 3.8, 1.13) success
pl-cpu (ubuntu-20.04, pytorch, 3.8, 1.13) success
pl-cpu (windows-2022, pytorch, 3.8, 1.13) success

These checks are required after the changes to src/lightning/fabric/wrappers.py.

🟢 pytorch_lightning: Azure GPU
Check ID Status
pytorch-lightning (GPUs) success

These checks are required after the changes to src/lightning/fabric/wrappers.py.

🟢 fabric: Docs
Check ID Status
make-doctest (fabric) success
make-html (fabric) success

These checks are required after the changes to src/lightning/fabric/wrappers.py.

🟢 lightning_fabric: CPU workflow
Check ID Status
fabric-cpu (macOS-11, lightning, 3.8, 1.11) success
fabric-cpu (macOS-11, lightning, 3.9, 1.12) success
fabric-cpu (macOS-11, lightning, 3.10, 1.13) success
fabric-cpu (macOS-11, lightning, 3.10, 2.0) success
fabric-cpu (macOS-11, lightning, 3.8, 1.11, oldest) success
fabric-cpu (ubuntu-20.04, lightning, 3.8, 1.11) success
fabric-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
fabric-cpu (ubuntu-20.04, lightning, 3.8, 1.11, oldest) success
fabric-cpu (windows-2022, lightning, 3.8, 1.11) success
fabric-cpu (windows-2022, lightning, 3.9, 1.12) success
fabric-cpu (windows-2022, lightning, 3.10, 1.13) success
fabric-cpu (windows-2022, lightning, 3.10, 2.0) success
fabric-cpu (windows-2022, lightning, 3.8, 1.11, oldest) success
fabric-cpu (macOS-11, fabric, 3.8, 1.13) success
fabric-cpu (ubuntu-20.04, fabric, 3.8, 1.13) success
fabric-cpu (windows-2022, fabric, 3.8, 1.13) success

These checks are required after the changes to src/lightning/fabric/wrappers.py, tests/tests_fabric/test_wrappers.py.

🟢 lightning_fabric: Azure GPU
Check ID Status
lightning-fabric (GPUs) success

These checks are required after the changes to src/lightning/fabric/wrappers.py, tests/tests_fabric/test_wrappers.py.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/lightning/fabric/wrappers.py.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, app, 3.8) success
install-pkg (ubuntu-22.04, app, 3.10) success
install-pkg (ubuntu-22.04, fabric, 3.8) success
install-pkg (ubuntu-22.04, fabric, 3.10) success
install-pkg (ubuntu-22.04, pytorch, 3.8) success
install-pkg (ubuntu-22.04, pytorch, 3.10) success
install-pkg (ubuntu-22.04, lightning, 3.8) success
install-pkg (ubuntu-22.04, lightning, 3.10) success
install-pkg (ubuntu-22.04, notset, 3.8) success
install-pkg (ubuntu-22.04, notset, 3.10) success
install-pkg (macOS-12, app, 3.8) success
install-pkg (macOS-12, app, 3.10) success
install-pkg (macOS-12, fabric, 3.8) success
install-pkg (macOS-12, fabric, 3.10) success
install-pkg (macOS-12, pytorch, 3.8) success
install-pkg (macOS-12, pytorch, 3.10) success
install-pkg (macOS-12, lightning, 3.8) success
install-pkg (macOS-12, lightning, 3.10) success
install-pkg (macOS-12, notset, 3.8) success
install-pkg (macOS-12, notset, 3.10) success
install-pkg (windows-2022, app, 3.8) success
install-pkg (windows-2022, app, 3.10) success
install-pkg (windows-2022, fabric, 3.8) success
install-pkg (windows-2022, fabric, 3.10) success
install-pkg (windows-2022, pytorch, 3.8) success
install-pkg (windows-2022, pytorch, 3.10) success
install-pkg (windows-2022, lightning, 3.8) success
install-pkg (windows-2022, lightning, 3.10) success
install-pkg (windows-2022, notset, 3.8) success
install-pkg (windows-2022, notset, 3.10) success

These checks are required after the changes to src/lightning/fabric/wrappers.py.

🟢 link-check
Check ID Status
check-md-links / markdown-link-check success

These checks are required after the changes to src/lightning/fabric/CHANGELOG.md.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates for 60 minutes every 180 seconds. If you have any other questions, contact carmocca for help.

@mergify bot added the ready (PRs ready to be merged) label on Apr 21, 2023
@awaelchli awaelchli merged commit 0ee71d6 into master Apr 21, 2023
@awaelchli awaelchli deleted the debug/ddp-wrapper branch April 21, 2023 19:29
Borda pushed commits referencing this pull request on Apr 24, 2023 (cherry picked from commit 0ee71d6).
lantiga pushed a commit referencing this pull request on Apr 24, 2023 (cherry picked from commit 0ee71d6).

Review comments on src/lightning/fabric/wrappers.py:
warning_cache.warn(
f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
" model. This will bypass the wrapper from the strategy and result in incorrect behavior in"
f" `.backward()`. You should pass your inputs through `{type(self._original_module)}.forward()`.",
Contributor

Was __name__ missed here? If so, can you include the change in #17531?


def _validate_method_access(self, name: str, attribute: Any) -> None:
    if inspect.ismethod(attribute) and self._forward_module != self._original_module:
        warning_cache.warn(...)  # emits the PossibleUserWarning quoted above
Contributor

I noticed we get some false positives from this in CI:

/__w/2/s/src/lightning_fabric/wrappers.py:158: PossibleUserWarning: You are calling the method `ConvNet.get_dataloader()` from outside the model. This will bypass the wrapper from the strategy and result in incorrect behavior in `.backward()`. You should pass your inputs through `<class 'tests_fabric.parity.models.ConvNet'>.forward()`.
  warning_cache.warn(
/__w/2/s/src/lightning_fabric/wrappers.py:158: PossibleUserWarning: You are calling the method `ConvNet.get_loss_function()` from outside the model. This will bypass the wrapper from the strategy and result in incorrect behavior in `.backward()`. You should pass your inputs through `<class 'tests_fabric.parity.models.ConvNet'>.forward()`.
  warning_cache.warn(

The pattern used in the test seems natural. I'm wondering whether we should relax this or reword it to make it clearer that this can be okay if you aren't running a computation inside.

Contributor Author

Yes, I was thinking the same. Sometimes it is OK, sometimes not.
Additionally, to reduce the message spam when lots of different function calls happen, we could show the warning only for the first use but not for the others.

Contributor

That might make it useless, as often the first call will be for something that's not a problem, like get_dataloader above.

I think the simplest improvement would be to clarify the warning wording so that it's easier to understand that this might be a false positive. "You should pass your inputs..." makes it sound like you always need to do this.

@patchmeifyoucan commented May 22, 2023

Hello, I am trying to understand how exactly to set up Fabric to train properly with DDP. I am writing code for distributed reinforcement learning, so the modules contain arbitrary submodules. In general, there are two types of calls I make to the model: 1) interacting with the environment, which does not require gradients, and 2) optimizing the modules, which does require gradients.

For case 1), I assume that everything works as intended, even though I had to do some manual device placement after upgrading from Fabric 2.0.0 to 2.0.2. Before that, the same code worked without these changes.

My main question is about case 2). As far as I understand, some calls are now redirected to a fake .forward() method, probably to sync gradients during backward, right? Since the title of this issue says "step methods", is there a naming pattern to be followed? Or does it simply wrap all calls?

The source of my confusion is that in the models I am implementing, there is no natural, single .forward() in the top-level module but rather calls to submodules, which themselves do have a natural forward method.

I have a training loop, which, after some interactions, calls this .optimize() method (how many submodules are optimized depends on the algorithm). My question now is whether there are any issues with this approach.

def optimize(self, optimizers) -> TensorDict:
    self.step += 1

    self.alpha = self.alpha.to(self.device)

    value_logs = self.optimize_value(optimizers["value"])
    policy_logs = self.optimize_policy(optimizers["policy"])

    return self.as_tensordict(**policy_logs, **value_logs)

Because I am not exactly sure which pattern is correct, I started redirecting calls to the submodules via the code below. The downside is that this makes everything uglier, more verbose, and error-prone, so I would like to avoid it.

def forward(self, inputs, module) -> TensorDict:
    if module == "policy":
        return self.action(inputs)

    if module == "q":
        return self.as_tensordict(
            q1=self.q1(inputs),
            q2=self.q2(inputs),
        )

    if module == "q-target":
        return self.as_tensordict(
            q1=self.q1_target(inputs),
            q2=self.q2_target(inputs),
        )

    raise ValueError(f"unknown module {module}")

So my question is whether this is even necessary (in the top-level module) or whether it is sufficient to simply call the forward method of the submodules from the top-level module. Also, do I have to follow any naming conventions (e.g., _step methods) such that the redirection is applied correctly?

@awaelchli (Contributor Author) commented May 28, 2023

My main question is about case 2). As far as I understand, some calls are now redirected to a fake .forward() method, probably to sync gradients during backward, right?

Yes.

Since the title of this issue says "step methods", is there a naming pattern to be followed? Or does it simply wrap all calls?

By step methods we mean the ones from LightningModule, namely training_step, validation_step, test_step, and predict_step. If a method is named like this, it will be treated specially and redirected through "forward".
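For illustration, a hedged usage sketch (MyLightningModule, batch, and batch_idx are placeholders, not real names from this thread):

import torch
from lightning.fabric import Fabric

fabric = Fabric(strategy="ddp")
fabric.launch()

model = MyLightningModule()  # placeholder LightningModule subclass
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)

loss = model.training_step(batch, batch_idx)  # redirected through the DDP forward
fabric.backward(loss)                         # gradients are synchronized correctly

out = model.some_other_method(batch)          # not redirected; triggers the new warning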

Or does it simply wrap all calls?

No, it does not, and likely never will, because 1) it adds some overhead and 2) we should stay as close as possible to PyTorch's behavior to avoid potential confusion or unexpected behavior.

So my question is whether this is even necessary (in the top-level module) or whether it is sufficient to simply call the forward method of the submodules from the top-level module.

Yes, it is necessary. And it is not sufficient to only set up/wrap the top-level module while calling the submodules from the outside, as this would bypass the DDP forward (this PR introduced a warning message if you do this).

If you don't like this, an alternative would be to set up each submodule individually, and get rid of the top-level module. You could inline the logic directly in the Fabric code.

module0 = fabric.setup(module0, ...)
module1 = fabric.setup(module1, ...)
module2 = fabric.setup(module2, ...)
# etc.

This will prepare each module with DDP and then you can call their forward individually, and DDP will work correctly this way. Whether this or the other way around is more convenient depends on the use case. Let me know what you think about this.

@patchmeifyoucan
Thanks for clarifying this. I will reconsider my current usage of Fabric. I ended up thinking about this for longer than expected, so I would like to summarize it in case someone else stumbles across it.

Let's say I have the following loop:

while current_steps_per_epoch < self.env_steps_per_epoch:
    trajectory = self.rollout(train_environment, agent, episode, reward_modules, is_warmup)
    agent.buffer.put(trajectory)

    for _ in range(len(trajectory)):
        agent.training_step(optimizers)
        current_steps_per_epoch += 1
        if current_steps_per_epoch == self.env_steps_per_epoch:
            break

Basically, agent.training_step() does the same as my above implementation with agent.optimize(), so it's nice that these calls work now. There, I can then optimize different modules at different frequencies with different batch sizes.

What I am doing there is more or less analogous to manual optimization. So given your explanation above, I should implement it using one of these options. Let's say we have a class Agent (not necessarily a LightningModule, as I understand) with submodules moduleA and moduleB and corresponding optimizers optimizerA and optimizerB. Then there are two options.

TL;DR:

In general, we should use the forward() of the LightningModule we set up using fabric.

Option 1

Option 1 is to build moduleA and moduleB as LightningModules using regular PyTorch, then set them as submodules in the Agent constructor. Agent in this case is a LightningModule which we set up using fabric.

moduleA = make_model()
moduleB = make_model()

optimizerA = Optimizer(moduleA.parameters())
optimizerB = Optimizer(moduleB.parameters())

agent = Agent(moduleA, moduleB)
agent, optimizerA, optimizerB = fabric.setup(agent, optimizerA, optimizerB)

training_step() then is:

def training_step(self, optimizers):
    batch = self.buffer.get(batch_size=self.batch_size)
    opt_A, opt_B = optimizers["A"], optimizers["B"]
    loss_A, loss_B = self.forward(batch)

    opt_A.zero_grad()
    self.fabric.backward(loss_A)
    opt_A.step()

    opt_B.zero_grad()
    self.fabric.backward(loss_B)
    opt_B.step()

Are there any implications of using one optimizer with multiple parameter groups? In some cases it's convenient to define a scalar parameter as its own "group" (like alpha in SAC). One could hack around it by wrapping it in a module, but that's quite some overkill for one scalar.

optimizer = Optimizer([
    dict(params=moduleA.parameters(), lr=1e-3),
    dict(params=moduleB.parameters(), lr=1e-4)
])
agent, optimizer = fabric.setup(agent, optimizer)

def training_step(self, optimizer):
    batch = self.buffer.get(batch_size=self.batch_size)
    loss_A, loss_B = self.forward(batch)

    optimizer.zero_grad()
    self.fabric.backward(loss_A)
    optimizer.step()

    optimizer.zero_grad()
    self.fabric.backward(loss_B)
    optimizer.step()

Option 2

Option 2 is to build moduleA and moduleB as LightningModules using regular PyTorch, but set them up using fabric first and then set them as submodules in the Agent constructor. Agent in this case is NOT a LightningModule, and hence we also do not set it up using fabric and do not use its forward().

moduleA = make_model()
moduleB = make_model()

optimizerA = Optimizer(moduleA.parameters())
optimizerB = Optimizer(moduleB.parameters())

moduleA, optimizerA = fabric.setup(moduleA, optimizerA)
moduleB, optimizerB = fabric.setup(moduleB, optimizerB)

agent = Agent(moduleA, moduleB)

training_step() then is:

def training_step(self, optimizers):
    batch = self.buffer.get(batch_size=self.batch_size)
    opt_A, opt_B = optimizers["A"], optimizers["B"]

    loss_A = self.moduleA.forward(batch)
    loss_B = self.moduleB.forward(batch)

    opt_A.zero_grad()
    self.fabric.backward(loss_A)
    opt_A.step()

    opt_B.zero_grad()
    self.fabric.backward(loss_B)
    opt_B.step()

Pros and Cons

Option 1 seems more natural for regular, non-distributed cases, I would say. I always thought of a LightningModule as the top-level object you optimize, and at first I was a bit confused to think of it as a normal object not doing any magic.

Going through the above process, I find Option 2 much more appealing in my case. If you need features like different training frequencies and batch sizes for different submodules, putting everything into the forward() function is a rather unrelated workaround just to make DDP work. If the operations are related, like having a double-Q network with two target networks, then it may make a little more sense, even though I would probably still prefer two technically identical modules and building the semantics into the optimization process. Option 1 also requires a dynamic computation graph (not an issue with modern frameworks, but TensorFlow probably wouldn't like it either).

So if there is no need to couple inputs or outputs of different submodules, the second option is the way to go, I would say.

Questions

If I have pairs of identical modules (like target networks in RL), and I only copy weights from the model being optimized into the clone, I don't need to set the clone up with fabric and could in principle simply deepcopy it in the constructor, avoiding this logic outside the module, right?

From the same setup, i.e., a non-LightningModule top-level module with submodules being LightningModules, I could again call the submodules' training_step, right? That could make sense if the models are independent and have rather complicated optimization steps, and one would like to keep things cleaned up.

So thanks again and please let me know if I missed the point somewhere.

@pableeto commented Aug 4, 2023

Hi, I wonder whether it is correct (or safe) to call a LightningModule's model hooks in Fabric, like
model.on_train_batch_start(batch, batch_idx)
Currently, this leads to the warning "You are calling the method from outside the model...". I am not sure whether I can ignore it. From the discussion above, it seems that it would be OK as long as I am not running a computation w.r.t. the model inside?

@awaelchli (Contributor Author)

From the discussion above, it seems that it would be OK as long as I am not running a computation w.r.t. the model inside?

Computation would be fine; it's just that if you return a tensor and want to call backward on it, DDP won't reduce the gradients.
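To illustrate the distinction (a hedged sketch, assuming model was set up via fabric.setup under DDP; compute_custom_loss is a placeholder for any method other than the special step methods):

model.on_train_batch_start(batch, batch_idx)  # OK: no backward on its output

loss = model.training_step(batch, batch_idx)  # OK: redirected through the wrapper
fabric.backward(loss)                         # DDP reduces the gradients

loss = model.compute_custom_loss(batch)       # warns: bypasses the DDP wrapper
fabric.backward(loss)                         # gradients would NOT be reduced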

Labels
bug (Something isn't working) · fabric (lightning.fabric.Fabric) · ready (PRs ready to be merged)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Calls to LightningModule's *_step methods in Fabric bypasses DDP wrapper
5 participants