Lower accuracy for a resnet-18 model with TVM #8679

Closed
crawlingcub opened this issue Aug 6, 2021 · 12 comments · Fixed by #8699

@crawlingcub
Contributor

Hi,

I am getting lower accuracy with TVM, on both the CUDA and CPU targets, than when running the same model in PyTorch. The model is a variant of ResNet-18. Find the link to download the model below.

You will have to download the ImageNet validation dataset and extract/sort it into a folder. Replace ./imagenet/data in the script with the path to that folder.

You can download the model from here, untar, and pass the path to the script below.

Environment:

TVM installed from source
Pytorch 1.8.1
Python 3.7
OS: Ubuntu 18.04
Cuda 11.1
GPUs: 8 NVIDIA GA100

Code:

import sys

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms

import metrics  # local module (not shown) providing accuracy(output, target, topk)

def eval_model_tvm(model, dataset, device):
    import tvm
    from tvm import relay
    from tvm.contrib.download import download_testdata
    from tvm.contrib import graph_executor
    import logging
    logger = logging.getLogger('compile_engine')
    logger.setLevel(logging.ERROR)

    validation_dataloader = DataLoader(dataset, batch_size=100, shuffle=False)
    if "cpu" in device.lower():
        target = tvm.target.Target("llvm", host="llvm")
    else:
        target = tvm.target.cuda()
    print("target", target)
    dev = tvm.device(str(target))
    model = model.to("cpu")
    model.eval()
    mod = None
    lib = None
    acc1s = []
    acc5s = []
    for i, (images, targets) in enumerate(validation_dataloader):
        print(i)
        input_name = "input0"
        if mod is None:
            scripted_model = torch.jit.trace(model, images).eval()
            print("scripted")
            input_data = images.numpy().astype("float32")  # only its shape is used below
            shape_list = [(input_name, input_data.shape)]
            mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

            with tvm.transform.PassContext(opt_level=3):
                lib = relay.build(mod, target=target, params=params)

        m = graph_executor.GraphModule(lib["default"](dev))
        m.set_input(input_name, tvm.nd.array(images.numpy()))
        m.run()
        output = m.get_output(0).numpy()
        acc1, acc5 = metrics.accuracy(torch.tensor(output), targets, topk=(1, 5))
        print("Batch {0}, acc1: {1} acc5: {2}".format(i, acc1, acc5))
        acc1s.append(acc1.item())
        acc5s.append(acc5.item())


    return {'acc1': np.mean(acc1s), 'acc5': np.mean(acc5s)}

def eval_model_vision(model, dataset, device, criterion, compute_metrics_fn):
    print("Running validation...")
    if not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)
    if not isinstance(dataset, DataLoader):
        validation_dataloader = DataLoader(dataset, batch_size=100, shuffle=True)
    else:
        validation_dataloader = dataset
    acc1s = []
    acc2s = []
    model.to(device)
    model.eval()
    print("Val size ", len(validation_dataloader))

    with torch.no_grad():
        for i, (images, target) in enumerate(validation_dataloader):
            # compute output
            images = images.to(device)
            target = target.to(device)

            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = compute_metrics_fn(output, target, topk=(1, 5))
            acc1s.append(acc1.item())
            acc2s.append(acc5.item())
            if i % 10 == 0:
                print(i, loss)

    return {'acc1': np.mean(acc1s), 'acc5': np.mean(acc2s)}

def load_dataset():
    tr = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    dataset = datasets.ImageNet("./imagenet/data", split="val", transform=tr)
    dataset = torch.utils.data.Subset(dataset, range(10000))
    split_datasets = dict()
    start = 0
    splits = [("train", 0.70), ("val", 0.10), ("test", 0.10)]
    for split in splits:
        indices = range(start, int(split[1]*len(dataset)) + start)
        split_datasets[split[0]] = torch.utils.data.Subset(dataset, indices)
        start = indices[-1] + 1
    dataset = split_datasets
    return dataset

model = torch.load(sys.argv[1]+'/model.pt')
dataset = load_dataset()
DEVICE="cuda"
res1 = eval_model_vision(model, dataset["val"], device=DEVICE,
                         criterion=torch.nn.CrossEntropyLoss(),
                         compute_metrics_fn=metrics.accuracy)
print(res1)

res2 = eval_model_tvm(model, dataset["val"], DEVICE)
print(res2)

Output:

Running validation...
Val size  10
0 tensor(12.4782, device='cuda:0')
{'acc1': 40.2, 'acc5': 63.9}
target cuda -keys=cuda,gpu -max_num_threads=1024 -model=unknown -thread_warp_size=32
0
scripted
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
{'acc1': 7.5, 'acc5': 21.7}

Please let me know if you need more info. Thanks!
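The script imports `accuracy` from a local `metrics` module that is not shown. A minimal NumPy stand-in, assuming it follows the usual top-k convention (the percentage of samples whose label is among the k highest-scoring classes), might look like this; `topk_accuracy` is a hypothetical name. Running the PyTorch and TVM logits for the same batch through one such function keeps the two numbers directly comparable.

```python
import numpy as np


def topk_accuracy(logits, targets, topk=(1, 5)):
    """Top-k accuracy, in percent, for one batch.

    logits: array of shape (batch, num_classes); targets: int labels of shape (batch,).
    """
    logits = np.asarray(logits)
    targets = np.asarray(targets)
    maxk = max(topk)
    # Column indices of the maxk highest scores in each row, best first.
    pred = np.argsort(-logits, axis=1)[:, :maxk]
    hits = pred == targets[:, None]  # (batch, maxk) boolean matrix
    return [float(100.0 * hits[:, :k].any(axis=1).mean()) for k in topk]


# Usage: row 0's label is ranked second, row 1's label is ranked first.
logits = np.array([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]])
targets = np.array([2, 0])
print(topk_accuracy(logits, targets, topk=(1, 2)))  # [50.0, 100.0]
```

Because the 1000-image split divides evenly into batches of 100, averaging per-batch values as the script does gives the same result as computing accuracy over the whole split.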

@masahi
Member

masahi commented Aug 6, 2021

PyTorch ResNet-18 is tested on every CI job, so I don't expect any accuracy difference.

Can you try evaluating a model that has not been serialized to disk? When PyTorch JIT models are serialized, PyTorch erases all type information. This has caused problems for us in the past.

@crawlingcub
Contributor Author

Hi, just to clarify, this is a mutant derived from a resnet-18 model, so the model structure is a bit different. We are testing the behavior of tvm when running some simple variants of well-tested models like resnet.

Can you try evaluating the model that is not serialized to disk?

What do you mean by this? The results are the same as before I serialized the model to disk.

@masahi
Member

masahi commented Aug 7, 2021

I meant: instead of model = torch.load(sys.argv[1]+'/model.pt'), create the model directly from a Python script. But if you have already tried that, something is indeed a bit off.

@masahi
Member

masahi commented Aug 7, 2021

Can you also try exporting to ONNX and using our ONNX frontend? That would tell us whether this is a frontend-specific issue.

@crawlingcub
Contributor Author

Ok, I will try that out

@crawlingcub
Contributor Author

Hi,

I tried exporting the original model to ONNX and running it with TVM's ONNX frontend. The ONNX results are accurate; in fact, they match what I get with PyTorch exactly. So this seems to be a bug in the PyTorch frontend?

@masahi
Member

masahi commented Aug 9, 2021

Ok. Can you send me the ONNX file and, if possible, the PyTorch model source?

@masahi masahi self-assigned this Aug 9, 2021
@crawlingcub
Contributor Author

Hi,

I have updated the link above to include both the PyTorch and ONNX models. Regarding the model source, I took the pretrained ResNet-18 model from torchvision and applied some simple mutations on top of it, such as adding noise to some weights, replacing an activation function, and adding a new layer. I can send you a model summary if needed.

@masahi
Member

masahi commented Aug 10, 2021

I can confirm that model.pt and the TVM model converted via ONNX give the same output. It is hard to compare the two TVM models, one coming from the PyTorch frontend and the other from ONNX, because the ONNX model has batch norm folded into the convolutions, so it contains no batch norm layers.
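Folding is why the two graphs do not line up layer by layer: in inference mode, a batch norm after a convolution is a per-output-channel affine map, so it can be absorbed into the convolution's weights and bias. The NumPy sketch below checks that identity on a 1x1 convolution (the same per-channel scaling applies to any kernel size). It is illustrative only, not the code the ONNX exporter runs; `fold_batchnorm`, `conv1x1`, and `batchnorm` are hypothetical helpers.

```python
import numpy as np


def fold_batchnorm(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Absorb an inference-mode BatchNorm into the preceding convolution.

    weight: (out_ch, in_ch, kh, kw); bias and all BN parameters: (out_ch,).
    """
    scale = gamma / np.sqrt(var + eps)                # per-output-channel factor
    folded_w = weight * scale[:, None, None, None]    # scale each output filter
    folded_b = (bias - mean) * scale + beta
    return folded_w, folded_b


def conv1x1(x, weight, bias):
    # 1x1 convolution on NCHW input: a per-pixel matrix multiply over channels.
    return np.einsum("oi,nihw->nohw", weight[:, :, 0, 0], x) + bias[None, :, None, None]


def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
    s = (1, -1, 1, 1)  # broadcast per-channel parameters over N, H, W
    return (y - mean.reshape(s)) / np.sqrt(var.reshape(s) + eps) * gamma.reshape(s) + beta.reshape(s)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4))
w = rng.standard_normal((5, 3, 1, 1))
b = rng.standard_normal(5)
gamma, beta, mean = rng.standard_normal((3, 5))
var = rng.random(5) + 0.5

reference = batchnorm(conv1x1(x, w, b), gamma, beta, mean, var)
fw, fb = fold_batchnorm(w, b, gamma, beta, mean, var)
folded = conv1x1(x, fw, fb)
print(np.allclose(reference, folded))  # True
```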

Since the differences from ResNet-18 seem small, and we know there is no issue with ResNet-18, maybe you can start with ResNet-18 and add your changes one at a time until the TVM result diverges? I'm testing accuracy with a script modified from https://github.com/apache/tvm/blob/main/tutorials/frontend/from_pytorch.py, and there is no need to train the model.
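The incremental approach suggested here amounts to a small loop: apply the mutations cumulatively, compare PyTorch and TVM outputs after each step, and stop at the first divergence. The sketch assumes a hypothetical `outputs_match` callback that rebuilds the model with the given mutations, runs both backends on the same batch, and compares the outputs (for example with `np.allclose` and a tolerance); the mutation names in the usage line are placeholders.

```python
from typing import Callable, List, Optional, Sequence, Tuple


def first_breaking_change(
    changes: Sequence[str],
    outputs_match: Callable[[Tuple[str, ...]], bool],
) -> Optional[str]:
    """Apply changes cumulatively; return the first one after which outputs diverge.

    outputs_match(applied) is a hypothetical callback: it should build the model
    with exactly the changes in `applied`, run it in PyTorch and in TVM on the
    same batch, and return True if the outputs agree within a tolerance.
    Returns None if every prefix of `changes` still matches.
    """
    applied: List[str] = []
    for change in changes:
        applied.append(change)
        if not outputs_match(tuple(applied)):
            return change
    return None


# Usage with placeholder names and a fake checker that "breaks" on the activation swap.
changes = ["add weight noise", "add extra layer", "swap ReLU for ELU", "more weight noise"]
print(first_breaking_change(changes, lambda applied: "swap ReLU for ELU" not in applied))
# swap ReLU for ELU
```

A linear scan mirrors "add the changes one at a time"; if each conversion is slow, a binary search over the prefix length would need fewer checks.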

@crawlingcub
Contributor Author

I went through the sequence of changes. Replacing one ReLU activation with ELU seems to cause the big drop in accuracy; none of the earlier changes affected the results much. Maybe there is an issue with the ELU implementation in the PyTorch frontend?
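A suspected conversion bug like this can be isolated by comparing the operator against a NumPy reference over a range of inputs that includes negative values. The sketch below assumes the standard ELU definition used by `torch.nn.ELU` and checks one illustrative rewrite of it in terms of ReLU and exp; it is not TVM's converter code and not a description of the fix in #8699. `wrong_sign` is a hypothetical buggy variant included to show that such a check catches errors in the negative branch.

```python
import numpy as np


def elu_reference(x, alpha=1.0):
    # ELU(x) = x for x > 0 and alpha * (exp(x) - 1) otherwise (the torch.nn.ELU definition).
    return np.where(x > 0, x, alpha * np.expm1(x))


def elu_from_relu(x, alpha=1.0):
    # Illustrative rewrite using only relu and exp: relu(x) - alpha * relu(1 - exp(x)).
    # For x > 0 the second term vanishes; for x <= 0 it equals alpha * (exp(x) - 1).
    return np.maximum(x, 0.0) - alpha * np.maximum(1.0 - np.exp(x), 0.0)


def wrong_sign(x, alpha=1.0):
    # Hypothetical buggy variant: the sign of the negative branch is flipped.
    return np.maximum(x, 0.0) + alpha * np.maximum(1.0 - np.exp(x), 0.0)


x = np.linspace(-5.0, 5.0, 101)  # include negatives, where ELU differs from ReLU
for alpha in (0.5, 1.0, 2.0):
    print(alpha,
          np.allclose(elu_reference(x, alpha), elu_from_relu(x, alpha)),
          np.allclose(elu_reference(x, alpha), wrong_sign(x, alpha)))
```

Because ELU and ReLU agree on positive inputs, an error confined to the negative branch can leave a network running normally while still shifting its accuracy.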

@masahi
Member

masahi commented Aug 10, 2021

Bingo! Fixed in #8699

Thanks for reporting.

@crawlingcub
Contributor Author

Awesome! Thanks for your help!
