This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

AMC supports resnet #2876

Merged 100 commits on Oct 10, 2020
Commits (100)
3a45961
Merge pull request #31 from microsoft/master
chicm-ms Aug 6, 2019
633db43
Merge pull request #32 from microsoft/master
chicm-ms Sep 9, 2019
3e926f1
Merge pull request #33 from microsoft/master
chicm-ms Oct 8, 2019
f173789
Merge pull request #34 from microsoft/master
chicm-ms Oct 9, 2019
508850a
Merge pull request #35 from microsoft/master
chicm-ms Oct 9, 2019
5a0e9c9
Merge pull request #36 from microsoft/master
chicm-ms Oct 10, 2019
e7df061
Merge pull request #37 from microsoft/master
chicm-ms Oct 23, 2019
2175cef
Merge pull request #38 from microsoft/master
chicm-ms Oct 29, 2019
2ccbfbb
Merge pull request #39 from microsoft/master
chicm-ms Oct 30, 2019
b29cb0b
Merge pull request #40 from microsoft/master
chicm-ms Oct 30, 2019
4a3ba83
Merge pull request #41 from microsoft/master
chicm-ms Nov 4, 2019
c8a1148
Merge pull request #42 from microsoft/master
chicm-ms Nov 4, 2019
73c6101
Merge pull request #43 from microsoft/master
chicm-ms Nov 5, 2019
6a518a9
Merge pull request #44 from microsoft/master
chicm-ms Nov 11, 2019
a0d587f
Merge pull request #45 from microsoft/master
chicm-ms Nov 12, 2019
e905bfe
Merge pull request #46 from microsoft/master
chicm-ms Nov 14, 2019
4b266f3
Merge pull request #47 from microsoft/master
chicm-ms Nov 15, 2019
237ff4b
Merge pull request #48 from microsoft/master
chicm-ms Nov 21, 2019
682be01
Merge pull request #49 from microsoft/master
chicm-ms Nov 25, 2019
133af82
Merge pull request #50 from microsoft/master
chicm-ms Nov 25, 2019
71a8a25
Merge pull request #51 from microsoft/master
chicm-ms Nov 26, 2019
d2a73bc
Merge pull request #52 from microsoft/master
chicm-ms Nov 26, 2019
198cf5e
Merge pull request #53 from microsoft/master
chicm-ms Dec 5, 2019
cdbfaf9
Merge pull request #54 from microsoft/master
chicm-ms Dec 6, 2019
7e9b29e
Merge pull request #55 from microsoft/master
chicm-ms Dec 10, 2019
d00c46d
Merge pull request #56 from microsoft/master
chicm-ms Dec 10, 2019
de7d1fa
Merge pull request #57 from microsoft/master
chicm-ms Dec 11, 2019
1835ab0
Merge pull request #58 from microsoft/master
chicm-ms Dec 12, 2019
24fead6
Merge pull request #59 from microsoft/master
chicm-ms Dec 20, 2019
0b7321e
Merge pull request #60 from microsoft/master
chicm-ms Dec 23, 2019
60058d4
Merge pull request #61 from microsoft/master
chicm-ms Dec 23, 2019
b111a55
Merge pull request #62 from microsoft/master
chicm-ms Dec 24, 2019
611c337
Merge pull request #63 from microsoft/master
chicm-ms Dec 30, 2019
4a1f14a
Merge pull request #64 from microsoft/master
chicm-ms Jan 10, 2020
7a9e604
Merge pull request #65 from microsoft/master
chicm-ms Jan 14, 2020
b8035b0
Merge pull request #66 from microsoft/master
chicm-ms Feb 4, 2020
47567d3
Merge pull request #67 from microsoft/master
chicm-ms Feb 10, 2020
614d427
Merge pull request #68 from microsoft/master
chicm-ms Feb 10, 2020
a0d9ed6
Merge pull request #69 from microsoft/master
chicm-ms Feb 11, 2020
22dc1ad
Merge pull request #70 from microsoft/master
chicm-ms Feb 19, 2020
0856813
Merge pull request #71 from microsoft/master
chicm-ms Feb 22, 2020
9e97bed
Merge pull request #72 from microsoft/master
chicm-ms Feb 25, 2020
16a1b27
Merge pull request #73 from microsoft/master
chicm-ms Mar 3, 2020
e246633
Merge pull request #74 from microsoft/master
chicm-ms Mar 4, 2020
0439bc1
Merge pull request #75 from microsoft/master
chicm-ms Mar 17, 2020
8b5613a
Merge pull request #76 from microsoft/master
chicm-ms Mar 18, 2020
43e8d31
Merge pull request #77 from microsoft/master
chicm-ms Mar 22, 2020
aae448e
Merge pull request #78 from microsoft/master
chicm-ms Mar 25, 2020
7095716
Merge pull request #79 from microsoft/master
chicm-ms Mar 25, 2020
c51263a
Merge pull request #80 from microsoft/master
chicm-ms Apr 11, 2020
9953c70
Merge pull request #81 from microsoft/master
chicm-ms Apr 14, 2020
f9136c4
Merge pull request #82 from microsoft/master
chicm-ms Apr 16, 2020
b384ad2
Merge pull request #83 from microsoft/master
chicm-ms Apr 20, 2020
ff592dd
Merge pull request #84 from microsoft/master
chicm-ms May 12, 2020
0b5378f
Merge pull request #85 from microsoft/master
chicm-ms May 18, 2020
a53e0b0
Merge pull request #86 from microsoft/master
chicm-ms May 25, 2020
3ea0b89
Merge pull request #87 from microsoft/master
chicm-ms May 28, 2020
cf3fb20
Merge pull request #88 from microsoft/master
chicm-ms May 28, 2020
7f4cdcd
Merge pull request #89 from microsoft/master
chicm-ms Jun 4, 2020
574db2c
Merge pull request #90 from microsoft/master
chicm-ms Jun 15, 2020
32bedcc
Merge pull request #91 from microsoft/master
chicm-ms Jun 21, 2020
6155aa4
Merge pull request #92 from microsoft/master
chicm-ms Jun 22, 2020
8139c9c
Merge pull request #93 from microsoft/master
chicm-ms Jun 23, 2020
43419d7
Merge pull request #94 from microsoft/master
chicm-ms Jun 28, 2020
6b6ee55
Merge pull request #95 from microsoft/master
chicm-ms Jun 28, 2020
1b975e0
Merge pull request #96 from microsoft/master
chicm-ms Jun 28, 2020
c8f3c5d
Merge pull request #97 from microsoft/master
chicm-ms Jun 29, 2020
4c306f0
Merge pull request #98 from microsoft/master
chicm-ms Jun 30, 2020
64de4c2
Merge pull request #99 from microsoft/master
chicm-ms Jun 30, 2020
0e5d3ac
Merge pull request #100 from microsoft/master
chicm-ms Jul 1, 2020
4a52608
Merge pull request #101 from microsoft/master
chicm-ms Jul 3, 2020
208b1ee
Merge pull request #102 from microsoft/master
chicm-ms Jul 8, 2020
e7b1a2e
Merge pull request #103 from microsoft/master
chicm-ms Jul 10, 2020
57bcc85
Merge pull request #104 from microsoft/master
chicm-ms Jul 22, 2020
030f5ef
Merge pull request #105 from microsoft/master
chicm-ms Jul 29, 2020
058c8b7
Merge pull request #106 from microsoft/master
chicm-ms Aug 2, 2020
9abd8c8
Merge pull request #107 from microsoft/master
chicm-ms Aug 10, 2020
13c6623
Merge pull request #108 from microsoft/master
chicm-ms Aug 11, 2020
b50b41e
Merge pull request #109 from microsoft/master
chicm-ms Aug 12, 2020
78f1418
Merge pull request #110 from microsoft/master
chicm-ms Aug 13, 2020
74acc8b
Merge pull request #111 from microsoft/master
chicm-ms Aug 17, 2020
5bf416a
Merge pull request #112 from microsoft/master
chicm-ms Aug 24, 2020
4a207f9
Merge pull request #113 from microsoft/master
chicm-ms Sep 3, 2020
8aef8fa
AMC supports resnet
chicm-ms Sep 8, 2020
7d16857
updates
chicm-ms Sep 11, 2020
7be897b
Merge pull request #114 from microsoft/master
chicm-ms Sep 16, 2020
f974b2c
Merge pull request #115 from microsoft/master
chicm-ms Sep 17, 2020
7d239df
use nni speedup
chicm-ms Sep 18, 2020
2860176
updates
chicm-ms Sep 18, 2020
0c2f59b
Merge pull request #116 from microsoft/master
chicm-ms Sep 21, 2020
b0ba247
updates
chicm-ms Sep 22, 2020
3c5cef2
Merge pull request #117 from microsoft/master
chicm-ms Sep 25, 2020
85a879e
Merge pull request #118 from microsoft/master
chicm-ms Oct 9, 2020
1cd19fa
Merge branch 'master' into amc_resnet
chicm-ms Oct 9, 2020
0279aee
change print to logging
chicm-ms Oct 9, 2020
d066c1f
updates
chicm-ms Oct 9, 2020
000bedc
updates
chicm-ms Oct 10, 2020
857cb55
updates
chicm-ms Oct 10, 2020
e486c4d
Merge pull request #119 from microsoft/master
chicm-ms Oct 10, 2020
f670f3b
Merge branch 'master' into amc_resnet
chicm-ms Oct 10, 2020
10 changes: 10 additions & 0 deletions docs/en_US/Compression/Pruner.md
@@ -529,6 +529,16 @@ You can view [example](https://github.com/microsoft/nni/blob/master/examples/mod
.. autoclass:: nni.compression.torch.AMCPruner
```

### Reproduced Experiment

We reproduced one of the experiments in [AMC: AutoML for Model Compression and Acceleration on Mobile Devices](https://arxiv.org/pdf/1802.03494.pdf): pruning **MobileNet** to 50% FLOPs on ImageNet. Our results are as follows:

| Model | Top 1 acc. (paper/ours) | Top 5 acc. (paper/ours) | FLOPs |
| --------- | ----------------------- | ----------------------- | ----- |
| MobileNet | 70.5% / 69.9% | 89.3% / 89.1% | 50% |

The experiment code can be found at [examples/model_compress](https://github.com/microsoft/nni/tree/master/examples/model_compress/amc/).
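
For a quick start, the sketch below shows roughly how the example script drives the pruner. It is a sketch only: `model`, `evaluator` (a validation function), and `val_loader` come from your own setup, `config_list` is illustrative, and the keyword values mirror the example's defaults rather than a guaranteed API.

```python
from nni.compression.torch import AMCPruner

# Illustrative config: let AMC decide per-layer sparsity for Conv2d layers.
config_list = [{'op_types': ['Conv2d']}]

# `model`, `evaluator` and `val_loader` are assumed to be defined as in
# examples/model_compress/amc/amc_search.py.
pruner = AMCPruner(
    model, config_list, evaluator, val_loader,
    model_type='mobilenet', dataset='cifar10',
    train_episode=800, flops_ratio=0.5)
pruner.compress()
```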

## ADMM Pruner
Alternating Direction Method of Multipliers (ADMM) is a mathematical optimization technique that decomposes the original nonconvex problem into two subproblems which can be solved iteratively. In the weight pruning problem, these two subproblems are solved via 1) gradient descent and 2) Euclidean projection, respectively.
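
For concreteness, the standard ADMM splitting used for weight pruning (our summary of the usual formulation, not text from these docs) is:

```latex
% Weight pruning as constrained optimization: f is the training loss,
% g the indicator function of the sparsity constraint set S.
\min_{W,\,Z}\; f(W) + g(Z) \quad \text{s.t.}\quad W = Z

% ADMM alternates a gradient-descent step, a Euclidean projection onto S,
% and a dual-variable update:
W^{k+1} = \arg\min_{W}\; f(W) + \tfrac{\rho}{2}\lVert W - Z^{k} + U^{k}\rVert_2^2
Z^{k+1} = \Pi_{S}\!\bigl(W^{k+1} + U^{k}\bigr)
U^{k+1} = U^{k} + W^{k+1} - Z^{k+1}
```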
40 changes: 21 additions & 19 deletions examples/model_compress/amc/amc_search.py
@@ -7,7 +7,7 @@

import torch
import torch.nn as nn

from torchvision.models import resnet
from nni.compression.torch import AMCPruner
from data import get_split_dataset
from utils import AverageMeter, accuracy
@@ -16,7 +16,8 @@

def parse_args():
parser = argparse.ArgumentParser(description='AMC search script')
parser.add_argument('--model_type', default='mobilenet', type=str, choices=['mobilenet', 'mobilenetv2'], help='model to prune')
parser.add_argument('--model_type', default='mobilenet', type=str, choices=['mobilenet', 'mobilenetv2', 'resnet18', 'resnet34', 'resnet50'],
help='model to prune')
parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset to use (cifar/imagenet)')
parser.add_argument('--batch_size', default=50, type=int, help='number of data batch size')
parser.add_argument('--data_root', default='./cifar10', type=str, help='dataset path')
@@ -28,27 +29,29 @@ def parse_args():
parser.add_argument('--train_episode', default=800, type=int, help='number of training episode')
parser.add_argument('--n_gpu', default=1, type=int, help='number of gpu to use')
parser.add_argument('--n_worker', default=16, type=int, help='number of data loader worker')
parser.add_argument('--job', default='train_export', type=str, choices=['train_export', 'export_only'],
help='search best pruning policy and export or just export model with searched policy')
parser.add_argument('--export_path', default=None, type=str, help='path for exporting models')
parser.add_argument('--searched_model_path', default=None, type=str, help='path for searched best wrapped model')
parser.add_argument('--suffix', default=None, type=str, help='suffix of auto-generated log directory')

return parser.parse_args()


def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1):
if model == 'mobilenet' and dataset == 'imagenet':
from mobilenet import MobileNet
net = MobileNet(n_class=1000)
elif model == 'mobilenetv2' and dataset == 'imagenet':
from mobilenet_v2 import MobileNetV2
net = MobileNetV2(n_class=1000)
elif model == 'mobilenet' and dataset == 'cifar10':
if dataset == 'imagenet':
n_class = 1000
elif dataset == 'cifar10':
n_class = 10
else:
raise ValueError('unsupported dataset')

if model == 'mobilenet':
from mobilenet import MobileNet
net = MobileNet(n_class=10)
elif model == 'mobilenetv2' and dataset == 'cifar10':
net = MobileNet(n_class=n_class)
elif model == 'mobilenetv2':
from mobilenet_v2 import MobileNetV2
net = MobileNetV2(n_class=10)
net = MobileNetV2(n_class=n_class)
elif model.startswith('resnet'):
net = resnet.__dict__[model](pretrained=True)
in_features = net.fc.in_features
net.fc = nn.Linear(in_features, n_class)
else:
raise NotImplementedError
if checkpoint_path:
@@ -130,7 +133,6 @@ def validate(val_loader, model, verbose=False):
}]
pruner = AMCPruner(
model, config_list, validate, val_loader, model_type=args.model_type, dataset=args.dataset,
train_episode=args.train_episode, job=args.job, export_path=args.export_path,
searched_model_path=args.searched_model_path,
flops_ratio=args.flops_ratio, lbound=args.lbound, rbound=args.rbound)
train_episode=args.train_episode, flops_ratio=args.flops_ratio, lbound=args.lbound,
rbound=args.rbound, suffix=args.suffix)
pruner.compress()
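
For reference, the new `resnet` branch in `get_model_and_checkpoint` above boils down to re-heading a pretrained torchvision model. A standalone sketch of that pattern (the model name and class count here are illustrative):

```python
import torch.nn as nn
from torchvision.models import resnet

# Look up a torchvision resnet constructor by name, as the script does.
net = resnet.__dict__['resnet50'](pretrained=True)

# Swap the 1000-class ImageNet head for one sized to the target dataset
# (10 classes here, matching CIFAR-10).
net.fc = nn.Linear(net.fc.in_features, 10)
```

With that branch in place, the search can target a resnet directly, e.g. `python amc_search.py --model_type resnet50 --dataset cifar10`, using the flags defined in `parse_args()` above.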
34 changes: 19 additions & 15 deletions examples/model_compress/amc/amc_train.py
@@ -16,6 +16,7 @@

from nni.compression.torch.pruning.amc.lib.net_measure import measure_model
from nni.compression.torch.pruning.amc.lib.utils import get_output_folder
from nni.compression.torch import ModelSpeedup

from data import get_dataset
from utils import AverageMeter, accuracy, progress_bar
@@ -28,17 +29,19 @@ def parse_args():
parser = argparse.ArgumentParser(description='AMC train / fine-tune script')
parser.add_argument('--model_type', default='mobilenet', type=str, help='name of the model to train')
parser.add_argument('--dataset', default='cifar10', type=str, help='name of the dataset to train')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--n_gpu', default=1, type=int, help='number of GPUs to use')
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
parser.add_argument('--n_worker', default=4, type=int, help='number of data loader worker')
parser.add_argument('--lr_type', default='exp', type=str, help='lr scheduler (exp/cos/step3/fixed)')
parser.add_argument('--n_epoch', default=50, type=int, help='number of epochs to train')
parser.add_argument('--lr', default=0.05, type=float, help='learning rate')
parser.add_argument('--n_gpu', default=4, type=int, help='number of GPUs to use')
parser.add_argument('--batch_size', default=256, type=int, help='batch size')
parser.add_argument('--n_worker', default=32, type=int, help='number of data loader worker')
parser.add_argument('--lr_type', default='cos', type=str, help='lr scheduler (exp/cos/step3/fixed)')
parser.add_argument('--n_epoch', default=150, type=int, help='number of epochs to train')
parser.add_argument('--wd', default=4e-5, type=float, help='weight decay')
parser.add_argument('--seed', default=None, type=int, help='random seed to set')
parser.add_argument('--data_root', default='./data', type=str, help='dataset path')
# resume
parser.add_argument('--ckpt_path', default=None, type=str, help='checkpoint path to fine tune')
parser.add_argument('--mask_path', default=None, type=str, help='mask path for speedup')

# run eval
parser.add_argument('--eval', action='store_true', help='Simply run eval')
parser.add_argument('--calc_flops', action='store_true', help='Calculate flops')
@@ -56,20 +59,21 @@ def get_model(args):
raise NotImplementedError

if args.model_type == 'mobilenet':
net = MobileNet(n_class=n_class).cuda()
net = MobileNet(n_class=n_class)
elif args.model_type == 'mobilenetv2':
net = MobileNetV2(n_class=n_class).cuda()
net = MobileNetV2(n_class=n_class)
else:
raise NotImplementedError

if args.ckpt_path is not None:
# the checkpoint can be a saved whole model object exported by amc_search.py, or a state_dict
# the checkpoint can be state_dict exported by amc_search.py or saved by amc_train.py
print('=> Loading checkpoint {} ..'.format(args.ckpt_path))
ckpt = torch.load(args.ckpt_path)
if type(ckpt) == dict:
net.load_state_dict(ckpt['state_dict'])
else:
net = ckpt
net.load_state_dict(torch.load(args.ckpt_path))
if args.mask_path is not None:
SZ = 224 if args.dataset == 'imagenet' else 32
data = torch.randn(2, 3, SZ, SZ)
ms = ModelSpeedup(net, data, args.mask_path)
ms.speedup_model()

net.to(args.device)
if torch.cuda.is_available() and args.n_gpu > 1:
@@ -204,7 +208,7 @@ def save_checkpoint(state, is_best, checkpoint_dir='.'):

if args.calc_flops:
IMAGE_SIZE = 224 if args.dataset == 'imagenet' else 32
n_flops, n_params = measure_model(net, IMAGE_SIZE, IMAGE_SIZE)
n_flops, n_params = measure_model(net, IMAGE_SIZE, IMAGE_SIZE, args.device)
print('=> Model Parameter: {:.3f} M, FLOPs: {:.3f}M'.format(n_params / 1e6, n_flops / 1e6))
exit(0)

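The new `--mask_path` branch above is the consumer side of this PR's export change: the search step saves a pruning mask, and the train script applies it with NNI's `ModelSpeedup` before fine-tuning. A minimal sketch, assuming a CIFAR-10 MobileNet and a mask file from the search step (`mask.pth` is a placeholder path):

```python
import torch
from nni.compression.torch import ModelSpeedup
from mobilenet import MobileNet  # local module shipped with the AMC examples

net = MobileNet(n_class=10)

# Dummy input at the training resolution (32 for CIFAR-10, 224 for ImageNet);
# ModelSpeedup traces the graph with it to propagate masks across layers.
dummy_input = torch.randn(2, 3, 32, 32)

# Physically remove masked channels so fine-tuning runs on the smaller model.
ms = ModelSpeedup(net, dummy_input, 'mask.pth')
ms.speedup_model()
```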
114 changes: 41 additions & 73 deletions src/sdk/pynni/nni/compression/torch/pruning/amc/amc_pruner.py
@@ -2,6 +2,7 @@
# Licensed under the MIT license.

import os
import logging
from copy import deepcopy
from argparse import Namespace
import numpy as np
@@ -15,6 +16,8 @@

torch.backends.cudnn.deterministic = True

_logger = logging.getLogger(__name__)

class AMCPruner(Pruner):
"""
A pytorch implementation of AMC: AutoML for Model Compression and Acceleration on Mobile Devices.
@@ -38,13 +41,6 @@ class AMCPruner(Pruner):
Data loader of validation dataset.
suffix: str
suffix to help you remember what experiment you ran. Default: None.
job: str
train_export: search best pruned model and export after search.
export_only: export a searched model, searched_model_path and export_path must be specified.
searched_model_path: str
when job == export_only, use searched_model_path to specify the path of the searched model.
export_path: str
path for exporting models

# parameters for pruning environment
model_type: str
@@ -118,9 +114,6 @@ def __init__(
evaluator,
val_loader,
suffix=None,
job='train_export',
export_path=None,
searched_model_path=None,
model_type='mobilenet',
dataset='cifar10',
flops_ratio=0.5,
@@ -149,9 +142,8 @@ def __init__(
epsilon=50000,
seed=None):

self.job = job
self.searched_model_path = searched_model_path
self.export_path = export_path
self.val_loader = val_loader
self.evaluator = evaluator

if seed is not None:
np.random.seed(seed)
@@ -165,11 +157,9 @@
# build folder and logs
base_folder_name = '{}_{}_r{}_search'.format(model_type, dataset, flops_ratio)
if suffix is not None:
base_folder_name = base_folder_name + '_' + suffix
self.output_dir = get_output_folder(output_dir, base_folder_name)

if self.export_path is None:
self.export_path = os.path.join(self.output_dir, '{}_r{}_exported.pth'.format(model_type, flops_ratio))
self.output_dir = os.path.join(output_dir, base_folder_name + '-' + suffix)
else:
self.output_dir = get_output_folder(output_dir, base_folder_name)

self.env_args = Namespace(
model_type=model_type,
@@ -182,47 +172,42 @@
channel_round=channel_round,
output=self.output_dir
)

self.env = ChannelPruningEnv(
self, evaluator, val_loader, checkpoint, args=self.env_args)

if self.job == 'train_export':
print('=> Saving logs to {}'.format(self.output_dir))
self.tfwriter = SummaryWriter(log_dir=self.output_dir)
self.text_writer = open(os.path.join(self.output_dir, 'log.txt'), 'w')
print('=> Output path: {}...'.format(self.output_dir))

nb_states = self.env.layer_embedding.shape[1]
nb_actions = 1 # just 1 action here

rmsize = rmsize * len(self.env.prunable_idx) # for each layer
print('** Actual replay buffer size: {}'.format(rmsize))

self.ddpg_args = Namespace(
hidden1=hidden1,
hidden2=hidden2,
lr_c=lr_c,
lr_a=lr_a,
warmup=warmup,
discount=discount,
bsize=bsize,
rmsize=rmsize,
window_length=window_length,
tau=tau,
init_delta=init_delta,
delta_decay=delta_decay,
max_episode_length=max_episode_length,
debug=debug,
train_episode=train_episode,
epsilon=epsilon
)
self.agent = DDPG(nb_states, nb_actions, self.ddpg_args)
_logger.info('=> Saving logs to %s', self.output_dir)
self.tfwriter = SummaryWriter(log_dir=self.output_dir)
self.text_writer = open(os.path.join(self.output_dir, 'log.txt'), 'w')
_logger.info('=> Output path: %s...', self.output_dir)

nb_states = self.env.layer_embedding.shape[1]
nb_actions = 1 # just 1 action here

rmsize = rmsize * len(self.env.prunable_idx) # for each layer
_logger.info('** Actual replay buffer size: %d', rmsize)

self.ddpg_args = Namespace(
hidden1=hidden1,
hidden2=hidden2,
lr_c=lr_c,
lr_a=lr_a,
warmup=warmup,
discount=discount,
bsize=bsize,
rmsize=rmsize,
window_length=window_length,
tau=tau,
init_delta=init_delta,
delta_decay=delta_decay,
max_episode_length=max_episode_length,
debug=debug,
train_episode=train_episode,
epsilon=epsilon
)
self.agent = DDPG(nb_states, nb_actions, self.ddpg_args)


def compress(self):
if self.job == 'train_export':
self.train(self.ddpg_args.train_episode, self.agent, self.env, self.output_dir)
self.export_pruned_model()
self.train(self.ddpg_args.train_episode, self.agent, self.env, self.output_dir)

def train(self, num_episode, agent, env, output_dir):
agent.is_training = True
@@ -263,12 +248,11 @@ def train(self, num_episode, agent, env, output_dir):
observation = deepcopy(observation2)

if done: # end of episode
print(
'#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}'.format(
_logger.info(
'#%d: episode_reward: %.4f acc: %.4f, ratio: %.4f',
episode, episode_reward,
info['accuracy'],
info['compress_ratio']
)
)
self.text_writer.write(
'#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}\n'.format(
@@ -310,19 +294,3 @@ def train(self, num_episode, agent, env, output_dir):
self.text_writer.write('best reward: {}\n'.format(env.best_reward))
self.text_writer.write('best policy: {}\n'.format(env.best_strategy))
self.text_writer.close()

def export_pruned_model(self):
if self.searched_model_path is None:
wrapper_model_ckpt = os.path.join(self.output_dir, 'best_wrapped_model.pth')
else:
wrapper_model_ckpt = self.searched_model_path
self.env.reset()
self.bound_model.load_state_dict(torch.load(wrapper_model_ckpt))

print('validate searched model:', self.env._validate(self.env._val_loader, self.env.model))
self.env.export_model()
self._unwrap_model()
print('validate exported model:', self.env._validate(self.env._val_loader, self.env.model))

torch.save(self.bound_model, self.export_path)
print('exported model saved to: {}'.format(self.export_path))
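
As orientation for reviewers, `train()` above follows the usual DDPG interaction loop from the upstream AMC code. The sketch below is an assumed shape, not NNI's exact internals: the method names `select_action`, `observe` and `update_policy` follow the original AMC implementation and may differ here.

```python
from copy import deepcopy

# Assumed-shape DDPG search loop; `agent`, `env`, `num_episode` and `warmup`
# are as in AMCPruner.train() above.
episode = 0
observation = None
while episode < num_episode:
    if observation is None:
        observation = deepcopy(env.reset())
    # The agent proposes a pruning ratio for the current layer.
    action = agent.select_action(observation, episode=episode)
    observation2, reward, done, info = env.step(action)
    # Store the transition; update actor/critic once past the warmup phase.
    agent.observe(reward, observation, observation2, done)
    if episode > warmup:
        agent.update_policy()
    observation = deepcopy(observation2)
    if done:  # every prunable layer has been assigned a ratio
        observation = None
        episode += 1
```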