Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Fix KeyError: 'aten::reshape' on calling torch.reshape #3341

Merged
merged 1 commit into from
Feb 2, 2021

Conversation

tczhangzhi
Copy link
Contributor

NNI's graph converter did not handle the aten::reshape operation, so the following code raises a KeyError when forward() calls reshape():

import random

import nni.retiarii.nn.pytorch as nn
import torch.nn.functional as F
from nni.retiarii.experiment import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.strategies import RandomStrategy
from nni.retiarii.trainer import PyTorchImageClassificationTrainer


class Net(nn.Module):
    """Small convolutional classifier used to reproduce the ``aten::reshape`` KeyError.

    The network is two convolution + max-pool stages followed by two
    fully connected layers. ``fc1`` is an ``nn.LayerChoice`` between a
    ``Linear`` layer with bias and one without.
    """

    def __init__(self, hidden_size):
        """Create the layers.

        Args:
            hidden_size: Number of output features of ``fc1`` and input
                features of ``fc2``.
        """
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        # Two candidate layers that differ only in whether a bias is used.
        self.fc1 = nn.LayerChoice([
            nn.Linear(4*4*50, hidden_size),
            nn.Linear(4*4*50, hidden_size, bias=False)
        ])
        self.fc2 = nn.Linear(hidden_size, 10)

    def forward(self, x):
        """Run the network on ``x`` and return log-softmax scores over dim 1.

        NOTE(review): the hard-coded ``4*4*50`` flatten size presumably
        assumes 1x28x28 inputs such as MNIST -- confirm against the dataset.
        """
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        # =========== reshape start ===========
        # Flatten to (-1, 4*4*50) to match the in_features of fc1. This call
        # is the aten::reshape operation that triggers the reported KeyError.
        x = x.reshape(-1, 4*4*50)
        # =========== reshape end ===========

        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


if __name__ == '__main__':
    # Base model with a hidden size of 128.
    base_model = Net(128)

    # One-epoch MNIST classification trainer for the base model.
    trainer = PyTorchImageClassificationTrainer(
        base_model,
        dataset_cls="MNIST",
        dataset_kwargs={"root": "data/mnist", "download": True},
        dataloader_kwargs={"batch_size": 32},
        optimizer_kwargs={"lr": 1e-3},
        trainer_kwargs={"max_epochs": 1},
    )

    strategy = RandomStrategy()

    # The third argument is an empty list.
    # NOTE(review): presumably the list of mutators -- confirm against the
    # RetiariiExperiment signature.
    exp = RetiariiExperiment(base_model, trainer, [], strategy)

    # Local experiment: 10 trials in total, 2 running concurrently.
    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'mnist_search'
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 10
    exp_config.training_service.use_active_gpu = False

    # Pick a random port in the inclusive range 7081..7181.
    port = 7081 + random.randint(0, 100)
    exp.run(exp_config, port)

The error message is as follows:

Traceback (most recent call last):
  File "test.py", line 50, in <module>
    exp.run(exp_config, 7081 + random.randint(0, 100))
  File "/home/zhangzhi/anaconda3/lib/python3.7/site-packages/nni/retiarii/experiment.py", line 173, in run
    super().run(port, debug)
  File "/home/zhangzhi/anaconda3/lib/python3.7/site-packages/nni/experiment/experiment.py", line 181, in run
    self.start(port, debug)
  File "/home/zhangzhi/anaconda3/lib/python3.7/site-packages/nni/retiarii/experiment.py", line 158, in start
    self._start_strategy()
  File "/home/zhangzhi/anaconda3/lib/python3.7/site-packages/nni/retiarii/experiment.py", line 124, in _start_strategy
    base_model_ir = convert_to_graph(script_module, self.base_model)
  File "/home/zhangzhi/anaconda3/lib/python3.7/site-packages/nni/retiarii/converter/graph_gen.py", line 526, in convert_to_graph
    GraphConverter().convert_module(script_module, module, module_name, model)
  File "/home/zhangzhi/anaconda3/lib/python3.7/site-packages/nni/retiarii/converter/graph_gen.py", line 487, in convert_module
    module_name, ir_model, ir_graph)
  File "/home/zhangzhi/anaconda3/lib/python3.7/site-packages/nni/retiarii/converter/graph_gen.py", line 351, in handle_graph_nodes
    handle_single_node(node)
  File "/home/zhangzhi/anaconda3/lib/python3.7/site-packages/nni/retiarii/converter/graph_gen.py", line 329, in handle_single_node
    aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], self.global_seq), node.kind())
KeyError: 'aten::reshape'

@ghost
Copy link

ghost commented Jan 27, 2021

CLA assistant check
All CLA requirements met.

@J-shang J-shang requested a review from QuanluZhang January 29, 2021 08:39
@J-shang J-shang closed this Feb 1, 2021
@J-shang J-shang reopened this Feb 1, 2021
@QuanluZhang QuanluZhang merged commit be3a696 into microsoft:master Feb 2, 2021
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants