Description
I have a fairly simple convolutional neural network with two distinct convolutional layers. I would like to turn it into a three-layer convolutional network by applying the second convolutional layer twice, with downsampling (AvgPool2d) between the two applications. Here is a diagram of my network architecture:
(Note that the two yellow layers are the same "layer", just applied twice; this is done to reduce the number of parameters.)
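(For scale: the shared 16-to-16 convolution with 3x3 kernels holds 16*16*3*3 + 16 = 2,320 parameters, so reusing it rather than adding a third convolution saves exactly that many.)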
When I convert my model to analog using convert_to_analog(), it works fine in the forward pass but gives me the following error upon calling .backward():
RuntimeError: Function AnalogFunctionBackward returned an invalid gradient at index 1 - got [256, 16, 16, 16] but expected shape compatible with [256, 16, 32, 32]
This does not occur on CPU, only GPU. Also, the original "digital" model works fine on both GPU and CPU. If I remove the "downsampling" layer (i.e. remove the AvgPool2d between the two convolutional layers), it works in all cases.
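Judging from the shapes in the error, index 1 appears to be the layer input: the first application of the shared convolution sees 32x32 feature maps while the second (after pooling) sees 16x16, so the analog backward pass seems to assume a single fixed input shape per layer.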
How to reproduce
Here is a minimal working example:
import torch
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import aihwkit.optim
from aihwkit.simulator.configs import InferenceRPUConfig
from aihwkit.nn.conversion import convert_to_analog

REPRODUCE_BUG = True
DEVICE = "cuda"

class SimpleCNN_w_reuse(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.conv_layers = torch.nn.ModuleList()
        self.act = torch.nn.ReLU()
        out_size = 32
        # Define two convolutional layers. The second one will be used twice.
        self.conv_layers.append(torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=1))
        self.conv_layers.append(torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1))
        if REPRODUCE_BUG:
            self.downsample = torch.nn.AvgPool2d(2)
            out_size = out_size // 2
        out_ch = 16
        # Global pool down to 1x1 before the classifier.
        self.global_pool = torch.nn.AvgPool2d(out_size)
        self.fc_layers = torch.nn.ModuleList()
        self.fc_layers.append(torch.nn.Linear(out_ch, 10))

    def forward(self, batch):
        h = batch
        # First convolution: 3 channels --> 16 channels.
        h = self.act(self.conv_layers[0](h))
        # Second convolution: 16 channels --> 16 channels.
        h = self.act(self.conv_layers[1](h))
        # If we want to reproduce the bug, we downsample the tensor.
        if REPRODUCE_BUG:
            h = self.downsample(h)
        # Third convolution REUSES THE SAME LAYER. The bug appears if the
        # input is a different size the second time around.
        h = self.act(self.conv_layers[1](h))
        h = self.global_pool(h)
        h = torch.flatten(h, start_dim=1)
        h = self.fc_layers[0](h)
        return h

def train():
    torch.autograd.set_detect_anomaly(True)
    device = DEVICE

    # Data
    train_ds = CIFAR10("./cifar10_files", train=True, transform=transforms.ToTensor(), download=True)
    train_dl = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=1, drop_last=True)

    d_network = SimpleCNN_w_reuse(device)

    # CONVERT TO ANALOG:
    rpu_config = InferenceRPUConfig()
    a_network = convert_to_analog(d_network, rpu_config=rpu_config)

    loss_fn = torch.nn.CrossEntropyLoss()
    a_optimizer = aihwkit.optim.AnalogAdam(a_network.parameters(), lr=1e-3)
    d_optimizer = torch.optim.Adam(d_network.parameters(), lr=1e-3)

    d_network.to(device)
    d_network.train()
    a_network.to(device)
    a_network.train()

    for epoch in range(100):
        for it, (batch, label) in enumerate(train_dl):
            batch = batch.to(device)
            label = label.to(device)
            print(f"Epoch {epoch + it / len(train_dl)},", end="")

            # Digital (reference) step: works on both CPU and GPU.
            d_optimizer.zero_grad()
            logits = d_network(batch)
            d_loss = loss_fn(logits, label)
            d_loss.backward()
            d_optimizer.step()
            print(f"D_loss: {d_loss.item()},", end="")

            # Analog step: a_loss.backward() raises the RuntimeError on GPU.
            a_optimizer.zero_grad()
            logits = a_network(batch)
            a_loss = loss_fn(logits, label)
            a_loss.backward()
            a_optimizer.step()
            print(f"A_loss: {a_loss.item()},", end="")
            print("")

if __name__ == "__main__":
    train()
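For what it's worth, a possible workaround (an untested sketch on my part, not anything confirmed by aihwkit) is to clone the shared layer before conversion, so that each application gets its own analog tile with a fixed input size, at the cost of the parameter savings:

import copy

# Hypothetical workaround sketch: give the second application of the shared
# convolution its own copy, so each analog tile only ever sees one input size.
d_network = SimpleCNN_w_reuse(DEVICE)
d_network.conv_layers.append(copy.deepcopy(d_network.conv_layers[1]))
# forward() would then need to call self.conv_layers[2] for the third
# convolution instead of reusing self.conv_layers[1], before converting:
a_network = convert_to_analog(d_network, rpu_config=InferenceRPUConfig())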
Expected behavior
The above example should run and train both the analog and digital versions of the model.
Other information
PyTorch version: 2.0.1+cu117
Package version: 0.8.0
OS: Ubuntu 22.04
Python version: 3.10.14
Conda version (or N/A): 23.3.1
Hi @KyleM-Irreversible,
thanks for reporting this. Indeed, re-using the same layer for two differently sized inputs is currently not supported. You can try the TorchInferenceRPUConfig (as @jubueche suggested), which implements a subset of the features of InferenceRPUConfig purely in torch instead of relying on the RPUCuda library. It might work in the case of re-using a layer with different input sizes, as it computes the backward pass differently.
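A minimal sketch of that swap, assuming TorchInferenceRPUConfig is importable from aihwkit.simulator.configs in this version; only the config class changes relative to the example above:

from aihwkit.simulator.configs import TorchInferenceRPUConfig
from aihwkit.nn.conversion import convert_to_analog

# Same conversion as in the MWE, but with the torch-based config.
rpu_config = TorchInferenceRPUConfig()
a_network = convert_to_analog(d_network, rpu_config=rpu_config)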