This repository has been archived by the owner on May 4, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassifier.py
executable file
·108 lines (93 loc) · 2.76 KB
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/env python3
# SPDX-FileCopyrightText: 2020 Harish Rajagopal <[email protected]>
#
# SPDX-License-Identifier: MIT
"""Train a classifier for FID."""
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, Namespace
from pathlib import Path
from typing import Final
import tensorflow as tf
from tensorflow.keras.mixed_precision import set_global_policy
from gan.data import get_dataset
from gan.models import Classifier
from gan.training import ClassifierTrainer
from gan.utils import load_config, setup_dirs
# File name handed to `setup_dirs` (as `file_name`) together with the loaded
# config; presumably the name under which each run's config is saved --
# TODO confirm against `gan.utils.setup_dirs`.
CONFIG: Final = "config-cls.toml"
def main(args: Namespace) -> None:
    """Train the classifier using the parsed command-line options.

    Arguments:
        args: The parsed command-line arguments. The attributes read here are
            `config`, `data_path`, `save_dir`, `log_dir`, `record_eps`,
            `save_steps` and `log_graph`.
    """
    config = load_config(args.config)

    strategy = tf.distribute.MirroredStrategy()

    # Switch the global Keras policy only when the config asks for it.
    if config.mixed_precision:
        set_global_policy("mixed_float16")

    train_data, test_data = get_dataset(args.data_path, config.cls_batch_size)

    # The model is instantiated under the distribution strategy's scope.
    with strategy.scope():
        classifier = Classifier(config)

    # Save each run into a directory by its timestamp; the first directory
    # returned by `setup_dirs` is used as this run's log directory.
    run_dirs = setup_dirs(
        dirs=[args.save_dir],
        dirs_to_tstamp=[args.log_dir],
        config=config,
        file_name=CONFIG,
    )
    log_dir = run_dirs[0]

    trainer = ClassifierTrainer(classifier, strategy, config=config)

    # Gather the keyword options for training in one place.
    train_options = {
        "log_dir": log_dir,
        "record_eps": args.record_eps,
        "save_dir": args.save_dir,
        "save_steps": args.save_steps,
        "log_graph": args.log_graph,
    }
    trainer.train(train_data, test_data, **train_options)
if __name__ == "__main__":
    # Every command-line option as (flags, add_argument keyword options),
    # listed in the order they are registered (and shown in --help).
    argument_specs = (
        (
            ("--data-path",),
            {
                "type": Path,
                "default": "./datasets/MNIST/",
                "help": "path to the dataset",
            },
        ),
        (
            ("-c", "--config"),
            {
                "type": Path,
                "help": "Path to a TOML config containing hyper-parameter values",
            },
        ),
        (
            ("--save-dir",),
            {
                "type": Path,
                "default": "./checkpoints/",
                "help": "directory where to save model",
            },
        ),
        (
            ("--save-steps",),
            {
                "type": int,
                "default": 5000,
                "help": "the frequency of saving the model (in steps)",
            },
        ),
        (
            ("--record-eps",),
            {
                "type": int,
                "default": 5,
                "help": "the frequency of recording summaries (in epochs)",
            },
        ),
        (
            ("--log-graph",),
            {
                "action": "store_true",
                "help": "whether to log the graph of the model",
            },
        ),
        (
            ("--log-dir",),
            {
                "type": Path,
                "default": "./logs/classifier",
                "help": "directory where to write event logs",
            },
        ),
    )

    cli = ArgumentParser(
        formatter_class=ArgumentDefaultsHelpFormatter,
        description="Train a classifier for FID",
    )
    for flags, options in argument_specs:
        cli.add_argument(*flags, **options)

    # Parse the command line and hand the resulting namespace to `main`.
    main(cli.parse_args())