Skip to content

Commit 73b138b

Browse files
authored
[Rust] Remove mxnet dependency and re-enable rust example (#17293)
* use torchvision's resnet18 instead of mxnet * re-enable rust example * update readme
1 parent 19b66bf commit 73b138b

File tree

5 files changed

+16
-27
lines changed

5 files changed

+16
-27
lines changed

rust/tvm/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ You can find the API Documentation [here](https://tvm.apache.org/docs/api/rust/t
2626

2727
The goal of this crate is to provide bindings to both the TVM compiler and runtime
2828
APIs. First train your **Deep Learning** model using any major framework such as
29-
[PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.apache.org/) or [TensorFlow](https://www.tensorflow.org/).
29+
[PyTorch](https://pytorch.org/) or [TensorFlow](https://www.tensorflow.org/).
3030
Then use **TVM** to build and deploy optimized model artifacts on a supported devices such as CPU, GPU, OpenCL and specialized accelerators.
3131

3232
The Rust bindings are composed of a few crates:

rust/tvm/examples/resnet/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ This end-to-end example shows how to:
2121
* build `Resnet 18` with `tvm` from Python
2222
* use the provided Rust frontend API to test for an input image
2323

24-
To run the example with pretrained resnet weights, first `tvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet`
24+
To run the example with pretrained resnet weights, first `tvm` and `torchvision` must be installed for the python build. To install torchvision for cpu, run `pip install torch torchvision`
2525
and to install `tvm` with `llvm` follow the [TVM installation guide](https://tvm.apache.org/docs/install/index.html).
2626

2727
* **Build the example**: `cargo build

rust/tvm/examples/resnet/build.rs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,6 @@ use anyhow::{Context, Result};
2121
use std::{io::Write, path::Path, process::Command};
2222

2323
fn main() -> Result<()> {
24-
// Currently disabled, as it depends on the no-longer-supported
25-
// mxnet repo to download resnet.
26-
27-
/*
2824
let out_dir = std::env::var("CARGO_MANIFEST_DIR")?;
2925
let python_script = concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py");
3026
let synset_txt = concat!(env!("CARGO_MANIFEST_DIR"), "/synset.txt");
@@ -57,7 +53,5 @@ fn main() -> Result<()> {
5753
);
5854
println!("cargo:rustc-link-search=native={}", out_dir);
5955

60-
*/
61-
6256
Ok(())
6357
}

rust/tvm/examples/resnet/src/build_resnet.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,18 @@
1717
# under the License.
1818

1919
import argparse
20-
import csv
2120
import logging
22-
from os import path as osp
23-
import sys
2421
import shutil
22+
from os import path as osp
2523

2624
import numpy as np
27-
25+
import torch
26+
import torchvision
2827
import tvm
29-
from tvm import te
30-
from tvm import relay, runtime
31-
from tvm.relay import testing
32-
from tvm.contrib import graph_executor, cc
3328
from PIL import Image
29+
from tvm import relay, runtime
30+
from tvm.contrib import cc, graph_executor
3431
from tvm.contrib.download import download_testdata
35-
from mxnet.gluon.model_zoo.vision import get_model
3632

3733
logging.basicConfig(
3834
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -64,11 +60,16 @@
6460

6561
def build(target_dir):
6662
"""Compiles resnet18 with TVM"""
67-
# Download the pretrained model in MxNet's format.
68-
block = get_model("resnet18_v1", pretrained=True)
63+
# Download the pretrained model from Torchvision.
64+
weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
65+
torch_model = torchvision.models.resnet18(weights=weights).eval()
66+
67+
input_shape = [1, 3, 224, 224]
68+
input_data = torch.randn(input_shape)
69+
scripted_model = torch.jit.trace(torch_model, input_data)
70+
input_infos = [("data", input_data.shape)]
71+
mod, params = relay.frontend.from_pytorch(scripted_model, input_infos)
6972

70-
shape_dict = {"data": (1, 3, 224, 224)}
71-
mod, params = relay.frontend.from_mxnet(block, shape_dict)
7273
# Add softmax to do classification in last layer.
7374
func = mod["main"]
7475
func = relay.Function(
@@ -93,7 +94,6 @@ def build(target_dir):
9394

9495
def download_img_labels():
9596
"""Download an image and imagenet1k class labels for test"""
96-
from mxnet.gluon.utils import download
9797

9898
synset_url = "".join(
9999
[

rust/tvm/examples/resnet/src/main.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@ use tvm_rt::graph_rt::GraphRt;
3131
use tvm_rt::*;
3232

3333
fn main() -> anyhow::Result<()> {
34-
// Currently disabled, as it depends on the no-longer-supported
35-
// mxnet repo to download resnet.
36-
37-
/*
3834
let dev = Device::cpu(0);
3935
println!("{}", concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png"));
4036

@@ -138,7 +134,6 @@ fn main() -> anyhow::Result<()> {
138134
"input image belongs to the class `{}` with probability {}",
139135
label, max_prob
140136
);
141-
*/
142137

143138
Ok(())
144139
}

0 commit comments

Comments
 (0)