Gradient error on GPU for subsequent Conv2d utilization after downsampling. #662

Open
KyleM-Irreversible opened this issue Jun 19, 2024 · 5 comments
Labels: bug (Something isn't working), status:reviewing

Comments


KyleM-Irreversible commented Jun 19, 2024

Description

I have a fairly simple convolutional neural network with two distinct convolutional layers. I would like to create a three-layer convolutional network by applying the second convolutional layer twice, with downsampling between the layers (using AvgPool2d). Here is a diagram of my network architecture:

[Diagram: network architecture (aihwkit_bug.drawio)]

(Note that the two yellow layers are the same "layer", just applied twice. This is done to reduce the number of parameters.)

When I convert my model to analog using convert_to_analog(), the forward pass works fine, but calling .backward() raises the following error:

RuntimeError: Function AnalogFunctionBackward returned an invalid gradient at index 1 - got [256, 16, 16, 16] but expected shape compatible with [256, 16, 32, 32]

This occurs only on GPU, not on CPU. The original "digital" model works fine on both GPU and CPU. If I remove the downsampling layer (i.e. remove the AvgPool2d between the two convolutional layers), everything works in all cases.
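
For reference, here is the shape flow that produces the two shapes in the error message (an annotated trace of the architecture above, with batch size 256 as in the example below):

# input:                  [256,  3, 32, 32]
# conv_layers[0] + ReLU:  [256, 16, 32, 32]
# conv_layers[1] + ReLU:  [256, 16, 32, 32]  <-- first use of the shared layer
# AvgPool2d(2):           [256, 16, 16, 16]
# conv_layers[1] + ReLU:  [256, 16, 16, 16]  <-- second use, smaller spatial size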

How to reproduce

Here is a minimum working example:

import torch
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from aihwkit.optim import AnalogAdam
from aihwkit.simulator.configs import InferenceRPUConfig
from aihwkit.nn.conversion import convert_to_analog


REPRODUCE_BUG = True
DEVICE = "cuda"

class SimpleCNN_w_reuse(torch.nn.Module):
    def __init__(self, device):  # device argument is unused here
        super().__init__()
        
        self.conv_layers = torch.nn.ModuleList()
        self.act = torch.nn.ReLU()
        out_size = 32

        #Define two convolutional layers. The second one will be used twice.
        self.conv_layers.append(torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=1))
        self.conv_layers.append(torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1))

        if REPRODUCE_BUG:
            self.downsample = torch.nn.AvgPool2d(2)
            out_size = out_size // 2

        out_ch = 16
        self.global_pool = torch.nn.AvgPool2d(out_size)

        self.fc_layers = torch.nn.ModuleList()
        # classifier head, applied after the global average pooling above
        self.fc_layers.append(torch.nn.Linear(out_ch, 10))

    def forward(self, batch):
        h = batch

        #first convolution 3 channels --> 16 channels
        h = self.act(self.conv_layers[0](h))

        #second convolution 16 channels --> 16 channels
        h = self.act(self.conv_layers[1](h))

        #If we want to reproduce the bug, we downsample the tensor.
        if REPRODUCE_BUG:
            h = self.downsample(h)
        
        #third convolution REUSES THE SAME LAYER
        #the bug appears if the input is a different size the second time around
        h = self.act(self.conv_layers[1](h))

        h = self.global_pool(h)
        h = torch.flatten(h, start_dim=1)
        h = self.fc_layers[0](h)
        
        return h


def train():

    torch.autograd.set_detect_anomaly(True)
    device = DEVICE

    # Data
    train_ds = CIFAR10("./cifar10_files", train=True, transform=transforms.ToTensor(), download=True)
    train_dl = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=1, drop_last=True)

    d_network = SimpleCNN_w_reuse(device)

    #CONVERT TO ANALOG:
    rpu_config = InferenceRPUConfig()
    a_network = convert_to_analog(d_network, rpu_config=rpu_config)  
    
    loss_fn = torch.nn.CrossEntropyLoss()


    a_optimizer = AnalogAdam(a_network.parameters(), lr=1e-3)
    d_optimizer = torch.optim.Adam(d_network.parameters(), lr=1e-3)
    
    d_network.to(device)
    d_network.train()
    a_network.to(device)
    a_network.train()

    for epoch in range(100):
        for it, (batch, label) in enumerate(train_dl):
            batch = batch.to(device)
            label = label.to(device)
            print(f"Epoch {epoch+it/len(train_dl)},", end="")


            d_optimizer.zero_grad()
            logits = d_network(batch)
            d_loss = loss_fn(logits, label)
            d_loss.backward()
            d_optimizer.step()
            print(f"D_loss: {d_loss.item()},", end="")

            a_optimizer.zero_grad()
            logits = a_network(batch)
            a_loss = loss_fn(logits, label)
            a_loss.backward()
            a_optimizer.step()
            print(f"A_loss: {a_loss.item()},", end="")

            print("")
            


if __name__ == "__main__":
    train()

Expected behavior

The above example should run and train both the analog and digital versions of the model.

Other information

  • Pytorch version: 2.0.1+cu117
  • Package version: 0.8.0
  • OS: Ubuntu 22.04
  • Python version: 3.10.14
  • Conda version (or N/A): 23.3.1
KyleM-Irreversible added the bug label on Jun 19, 2024
@kaoutar55 (Collaborator)

Thank you for reporting this issue. We will try to reproduce and fix it.

@jubueche (Collaborator)

Can you try TorchInferenceRPUConfig instead of InferenceRPUConfig?
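
A minimal sketch of that change, assuming TorchInferenceRPUConfig is exported from aihwkit.simulator.configs (as in recent aihwkit releases):

from aihwkit.simulator.configs import TorchInferenceRPUConfig
from aihwkit.nn.conversion import convert_to_analog

# Swap the RPUCuda-backed inference config for the pure-torch one.
rpu_config = TorchInferenceRPUConfig()
a_network = convert_to_analog(d_network, rpu_config=rpu_config)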

@maljoras (Collaborator)

Hi @KyleM-Irreversible ,
thanks for reporting this. Indeed, re-using the same layer for two differently sized inputs is currently not supported. You can try the TorchInferenceRPUConfig (as @jubueche suggested), which implements a subset of the features of InferenceRPUConfig purely in torch instead of relying on the RPUCuda library. It might work when re-using a layer with different input sizes, since it computes the backward pass differently.
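
If the torch-based config does not resolve it, a model-side workaround is to avoid re-using one analog layer at two spatial sizes, e.g. by giving the second use its own layer. A hypothetical sketch for the model's __init__ (note that this gives up the parameter sharing of the original design):

# Give the third convolution its own layer so that each analog tile only
# ever sees one input size. Weights start identical to conv_layers[1] but
# train independently after convert_to_analog().
conv_copy = torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1)
conv_copy.load_state_dict(self.conv_layers[1].state_dict())
self.conv_layers.append(conv_copy)

forward() would then call self.conv_layers[2] for the third convolution.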

@kaoutar55 (Collaborator)

@KyleM-Irreversible have you tried the workarounds that @maljoras and @jubueche recommended? Please let us know. Thanks!

@PabloCarmona (Collaborator)

Hello @KyleM-Irreversible! Any updates on this? Do you need any further help? Please let us know so we can close the issue, thanks!
