We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 5bcb830 commit 21233e1Copy full SHA for 21233e1
train.py
@@ -65,6 +65,9 @@
65
" new model instance.")
66
67
# Training flags.
68
+ flags.DEFINE_integer("num_gpu", 1,
69
+ "The maximum number of GPU devices to use for training. "
70
+ "Flag only applies if GPUs are installed")
71
flags.DEFINE_integer("batch_size", 1024,
72
"How many examples to process per batch for training.")
73
flags.DEFINE_string("label_loss", "CrossEntropyLoss",
@@ -220,6 +223,7 @@ def build_graph(reader,
220
223
221
224
local_device_protos = device_lib.list_local_devices()
222
225
gpus = [x.name for x in local_device_protos if x.device_type == 'GPU']
226
+ gpus = gpus[:FLAGS.num_gpu]
227
num_gpus = len(gpus)
228
229
if num_gpus > 0:
0 commit comments