-
Notifications
You must be signed in to change notification settings - Fork 352
/
Copy pathmain.rs
72 lines (63 loc) · 2.9 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// This is inspired by the Neural Style tutorial from PyTorch.org
// https://pytorch.org/tutorials/advanced/neural_style_tutorial.html
// The pre-trained weights for the VGG16 model can be downloaded from:
// https://github.com/LaurentMazare/tch-rs/releases/download/mw/vgg16.ot
use anyhow::{bail, Context, Result};
use tch::vision::{imagenet, vgg};
use tch::{nn, nn::OptimizerConfig, Device, Tensor};
/// Relative weight of the style loss versus the content loss in the total objective.
const STYLE_WEIGHT: f64 = 1e6;
/// Step size for the Adam optimizer that updates the image pixels.
const LEARNING_RATE: f64 = 1e-1;
/// Total number of optimization steps to run.
const TOTAL_STEPS: i64 = 3000;
/// Indexes of the VGG16 layers whose activations define the style targets.
const STYLE_INDEXES: [usize; 5] = [0, 2, 5, 7, 10];
/// Indexes of the VGG16 layers whose activations define the content target.
const CONTENT_INDEXES: [usize; 1] = [7];
/// Gram matrix of a 4-D activation tensor: flatten to (batch*channels, h*w),
/// take the product with its own transpose, and normalize by the total
/// element count so the result is independent of the image size.
fn gram_matrix(m: &Tensor) -> Tensor {
    let (batch, channels, height, width) = m.size4().unwrap();
    let flat = m.view([batch * channels, height * width]);
    let gram = flat.matmul(&flat.tr());
    gram / (batch * channels * height * width)
}
/// Mean-squared error between the Gram matrices of two activation tensors.
fn style_loss(m1: &Tensor, m2: &Tensor) -> Tensor {
    let g1 = gram_matrix(m1);
    let g2 = gram_matrix(m2);
    g1.mse_loss(&g2, tch::Reduction::Mean)
}
/// Neural style transfer: optimizes an image (initialized from the content
/// image) so its VGG16 activations match the content image on the content
/// layers and the Gram matrices of the style image on the style layers.
///
/// Usage: `main style.jpg content.jpg vgg16.ot`
///
/// # Errors
/// Returns an error on bad arguments, on failure to load the weights or
/// either image, or on failure to save an intermediate output image.
pub fn main() -> Result<()> {
    let device = Device::cuda_if_available();
    let args: Vec<_> = std::env::args().collect();
    let (style_img, content_img, weights) = match args.as_slice() {
        [_, s, c, w] => (s.to_owned(), c.to_owned(), w.to_owned()),
        _ => bail!("usage: main style.jpg content.jpg vgg16.ot"),
    };

    // Load the pre-trained VGG16 and freeze it: only the input image is
    // optimized, never the network weights.
    // Propagate load failures as errors (main returns Result) instead of
    // panicking, keeping the underlying cause in the error chain.
    let mut net_vs = tch::nn::VarStore::new(device);
    let net = vgg::vgg16(&net_vs.root(), imagenet::CLASS_COUNT);
    net_vs.load(&weights).with_context(|| format!("could not load weights file {weights}"))?;
    net_vs.freeze();

    let style_img = imagenet::load_image(&style_img)
        .with_context(|| format!("could not load the style file {style_img}"))?
        .unsqueeze(0)
        .to_device(device);
    let content_img = imagenet::load_image(&content_img)
        .with_context(|| format!("could not load the content file {content_img}"))?
        .unsqueeze(0)
        .to_device(device);

    // Target activations are fixed, so compute them once. Only layers up to
    // the deepest style layer are needed.
    let max_layer = STYLE_INDEXES.iter().max().unwrap() + 1;
    let style_layers = net.forward_all_t(&style_img, false, Some(max_layer));
    let content_layers = net.forward_all_t(&content_img, false, Some(max_layer));

    // The trainable variable is the image itself, seeded from the content image.
    let vs = nn::VarStore::new(device);
    let input_var = vs.root().var_copy("img", &content_img);
    let mut opt = nn::Adam::default().build(&vs, LEARNING_RATE)?;

    for step_idx in 1..=TOTAL_STEPS {
        let input_layers = net.forward_all_t(&input_var, false, Some(max_layer));
        let style_loss: Tensor =
            STYLE_INDEXES.iter().map(|&i| style_loss(&input_layers[i], &style_layers[i])).sum();
        let content_loss: Tensor = CONTENT_INDEXES
            .iter()
            .map(|&i| input_layers[i].mse_loss(&content_layers[i], tch::Reduction::Mean))
            .sum();
        let loss = style_loss * STYLE_WEIGHT + content_loss;
        opt.backward_step(&loss);
        // Log and snapshot the image every 1000 steps.
        if step_idx % 1000 == 0 {
            println!("{step_idx} {}", f64::try_from(loss)?);
            imagenet::save_image(&input_var, format!("out{step_idx}.jpg"))?;
        }
    }
    Ok(())
}