
Fairscale checkpoint-wrapper deconstructs NamedTuple outputs #954

Open
rceballos98 opened this issue Mar 9, 2022 · 3 comments

Comments

@rceballos98

rceballos98 commented Mar 9, 2022

It seems that `checkpoint_wrapper` turns a `NamedTuple` output into a plain `tuple`. This makes it difficult to type-check the interfaces between models.
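To illustrate the symptom, here is a minimal pure-Python sketch. It is not fairscale's code: it assumes the wrapper rebuilds the output with something like `tuple(...)`, and `Output` and `rebuild` are hypothetical stand-ins. The rebuilt value compares equal to the original but is no longer the same type, so `type(...)` checks fail even though `isinstance(..., tuple)` still passes.

```python
from typing import Any, NamedTuple


class Output(NamedTuple):
    # Hypothetical stand-in for SimpleModelOutput, holding plain values
    intermediate: Any
    final: Any


def rebuild(values: tuple) -> tuple:
    # Hypothetical container helper: rebuilding with tuple() keeps the
    # elements but drops the NamedTuple subclass and its field names
    return tuple(v for v in values)


out = Output(intermediate=1, final=2)
rebuilt = rebuild(out)

print(type(rebuilt).__name__)     # tuple
print(rebuilt == out)             # True: same elements
print(hasattr(rebuilt, "final"))  # False: field access is lost
```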

Here is a sample test to reproduce:

```python
from typing import NamedTuple

import torch
from fairscale.nn import checkpoint_wrapper
from torch import Tensor
from torch.nn import (
    GELU,
    Conv2d,
    Flatten,
    Linear,
    MaxPool2d,
    Module,
    Sequential,
)


class SimpleModelOutput(NamedTuple):
    intermediate: Tensor
    final: Tensor


class SimpleModel(Module):
    def __init__(
        self,
        output_dim: int,
    ):
        super().__init__()

        self.output_dim = output_dim
        # For a 3x28x28 input: 3x28x28 -> 20x12x12 -> 50x8x8
        self.model_1 = Sequential(
            Conv2d(3, 20, kernel_size=5),
            GELU(),
            MaxPool2d(2, stride=2),
            Conv2d(20, 50, kernel_size=5),
            GELU(),
        )

        # 50x8x8 -> 50x4x4 -> flattened 800 -> output_dim
        self.model_2 = Sequential(
            MaxPool2d(2, stride=2),
            Flatten(),
            Linear(50 * 4 * 4, output_dim),
            GELU(),
        )

    def forward(self, x: Tensor) -> SimpleModelOutput:
        intermediate: Tensor = self.model_1(x)
        final: Tensor = self.model_2(intermediate)

        return SimpleModelOutput(intermediate=intermediate, final=final)


def test_return_type_for_checkpoint_wrapper():
    model_1 = SimpleModel(500)

    input_data = torch.rand(1, 3, 28, 28)
    out1 = model_1(input_data)

    model_2 = checkpoint_wrapper(model_1)
    out2 = model_2(input_data)

    # checkpoint_wrapper seems to turn this model's NamedTuple output into a plain tuple
    assert type(out1) is type(out2), f"Non-matching types: {type(out1)=} != {type(out2)=}"
```

Here is the output:

[image: screenshot of the test output]
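Until this is fixed, one possible workaround is to re-wrap the plain tuple into the expected `NamedTuple` after the checkpointed call. The sketch below is an assumption, not a fairscale feature: `restore_namedtuple` is a hypothetical decorator, and `SimpleModelOutput` here is a stand-in holding plain values so it runs without PyTorch. It relies on the wrapper keeping the fields in their original order.

```python
import functools
from typing import Any, Callable, NamedTuple, Type


def restore_namedtuple(cls: Type[Any]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical decorator that re-wraps a plain-tuple result into the NamedTuple `cls`."""
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            result = fn(*args, **kwargs)
            # Leave results that already have the right type untouched
            if isinstance(result, cls):
                return result
            # _make builds a NamedTuple instance from any iterable, in field order
            return cls._make(result)
        return wrapper
    return decorator


class SimpleModelOutput(NamedTuple):
    # Stand-in with plain values instead of tensors
    intermediate: Any
    final: Any


@restore_namedtuple(SimpleModelOutput)
def wrapped_forward(x: int) -> tuple:
    # Stand-in for a checkpointed forward that returns a plain tuple
    return (x, x + 1)


out = wrapped_forward(1)
print(type(out).__name__, out.final)  # SimpleModelOutput 2
```

In the test above, the equivalent one-liner would be `out2 = SimpleModelOutput._make(model_2(input_data))`. This only restores the outer type; it does not help with nested containers.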

@rceballos98
Author

rceballos98 commented Mar 9, 2022

tagging @gshaikov

@rceballos98
Author

accidentally closed 😅

rceballos98 reopened this Mar 9, 2022
@min-xu-ai
Contributor

Thanks for the excellent issue report, @rceballos98. I suspect the problem is in container.py, which handles the different container types. Is this only a problem for your type checking, or are there issues beyond that?
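One possible direction for a container helper like the one mentioned, sketched in plain Python under the assumption that it splits a tuple into tensor and non-tensor parts and later merges them back: record the original container type during the split and rebuild NamedTuples with `_make`. `split_mixed`, `merge_mixed`, and the `is_tensor` predicate are hypothetical names, not fairscale's actual container.py API; floats stand in for tensors so the sketch runs without PyTorch.

```python
from typing import Any, Callable, Dict, List, NamedTuple, Tuple


def split_mixed(mixed: tuple, is_tensor: Callable[[Any], bool]) -> Tuple[List[Any], Dict[str, Any]]:
    """Separate tensor-like items from the rest, remembering how to rebuild."""
    tensors: List[Any] = []
    others: List[Any] = []
    flags: List[bool] = []
    for item in mixed:
        flags.append(is_tensor(item))
        (tensors if flags[-1] else others).append(item)
    # Keep the container type (e.g. a NamedTuple subclass) for reconstruction
    return tensors, {"others": others, "flags": flags, "type": type(mixed)}


def merge_mixed(tensors: List[Any], packed: Dict[str, Any]) -> tuple:
    """Interleave tensors and non-tensors back into the original container type."""
    tensor_iter = iter(tensors)
    other_iter = iter(packed["others"])
    items = [next(tensor_iter) if flag else next(other_iter) for flag in packed["flags"]]
    container_type = packed["type"]
    if hasattr(container_type, "_fields"):
        # NamedTuple constructors take positional fields, so use _make
        return container_type._make(items)
    return container_type(items)


class Output(NamedTuple):
    intermediate: Any
    final: Any
    note: str


def is_float(x: Any) -> bool:
    # Stand-in predicate: treat floats as "tensors"
    return isinstance(x, float)


out = Output(intermediate=1.0, final=2.0, note="debug")
tensors, packed = split_mixed(out, is_float)
restored = merge_mixed(tensors, packed)
print(tensors)                  # [1.0, 2.0]
print(type(restored).__name__)  # Output
print(restored == out)          # True
```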

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants