Skip to content

Commit ee3f7bc

Browse files
authored
[MSC][M5.3] Support torch.dynamo for dynamic models (#16772)
* add dynamic * add howto * update test
1 parent b91d4e5 commit ee3f7bc

File tree

38 files changed

+4072
-1605
lines changed

38 files changed

+4072
-1605
lines changed

gallery/how_to/work_with_msc/using_tools.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,12 @@
5757
parser.add_argument("--test_iter", type=int, default=100, help="The iter for test")
5858
parser.add_argument("--calibrate_iter", type=int, default=100, help="The iter for calibration")
5959
parser.add_argument("--train_batch", type=int, default=32, help="The batch size for train")
60-
parser.add_argument("--train_iter", type=int, default=200, help="The iter for train")
61-
parser.add_argument("--train_epoch", type=int, default=100, help="The epoch for train")
60+
parser.add_argument("--train_iter", type=int, default=100, help="The iter for train")
61+
parser.add_argument("--train_epoch", type=int, default=5, help="The epoch for train")
6262
parser.add_argument(
6363
"--verbose", type=str, default="info", help="The verbose level, info|debug:1,2,3|critical"
6464
)
65+
parser.add_argument("--dynamic", action="store_true", help="Whether to use dynamic wrapper")
6566
args = parser.parse_args()
6667

6768

@@ -88,8 +89,8 @@ def get_config(calib_loader, train_loader):
8889
compile_type=args.compile_type,
8990
dataset=dataset,
9091
tools=tools,
91-
skip_config={"all": "check"},
9292
verbose=args.verbose,
93+
dynamic=args.dynamic,
9394
)
9495

9596

@@ -100,13 +101,13 @@ def _get_calib_datas():
100101
for i, (inputs, _) in enumerate(testloader, 0):
101102
if i >= args.calibrate_iter > 0:
102103
break
103-
yield {"input": inputs}
104+
yield inputs if args.dynamic else {"input": inputs}
104105

105106
def _get_train_datas():
106107
for i, (inputs, _) in enumerate(trainloader, 0):
107108
if i >= args.train_iter > 0:
108109
break
109-
yield {"input": inputs}
110+
yield inputs if args.dynamic else {"input": inputs}
110111

111112
model = resnet50(pretrained=args.checkpoint)
112113
if torch.cuda.is_available():

python/tvm/contrib/msc/core/gym/environment/method.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _get_loss(golden, result):
105105
outputs = runner.run(inputs)
106106
baseline = loader[idx]
107107
for name, data in outputs.items():
108-
loss += _get_loss(baseline[name], data)
108+
loss += _get_loss(baseline[name], msc_utils.cast_array(data))
109109
return {"loss": loss / len(loader)}
110110

111111
@classmethod

python/tvm/contrib/msc/core/gym/environment/quantize_env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def _summary(self, actions: List[dict], rewards: List[dict]) -> Union[dict, str]
7070
continue
7171
info.update(strategys[name].get_executor(msc_utils.MSCStage.QUANTIZE).config)
7272
summary_file = msc_utils.get_cache_dir().relpath("gym_summary.json")
73-
return msc_utils.dump_dict(plan, summary_file)
73+
return msc_utils.save_dict(plan, summary_file)
7474

7575
@classmethod
7676
def role_type(cls):

python/tvm/contrib/msc/core/runtime/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@
1717
"""tvm.contrib.msc.core.runtime"""
1818

1919
from .runner import *
20+
from .jit import *

0 commit comments

Comments
 (0)