Skip to content

Commit 21233e1

Browse files
committed
Defaulting GPUs to 1
1 parent 5bcb830 commit 21233e1

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

train.py

+4
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@
6565
" new model instance.")
6666

6767
# Training flags.
68+
flags.DEFINE_integer("num_gpu", 1,
69+
"The maximum number of GPU devices to use for training. "
70+
"Flag only applies if GPUs are installed")
6871
flags.DEFINE_integer("batch_size", 1024,
6972
"How many examples to process per batch for training.")
7073
flags.DEFINE_string("label_loss", "CrossEntropyLoss",
@@ -220,6 +223,7 @@ def build_graph(reader,
220223

221224
local_device_protos = device_lib.list_local_devices()
222225
gpus = [x.name for x in local_device_protos if x.device_type == 'GPU']
226+
gpus = gpus[:FLAGS.num_gpu]
223227
num_gpus = len(gpus)
224228

225229
if num_gpus > 0:

0 commit comments

Comments
 (0)