This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 9729521

Lukasz Kaiser authored and Copybara-Service committed
Remove local sharding to CPU (it was only for debugging anyway) and enable DistributionStrategy by default.
PiperOrigin-RevId: 219397234
1 parent 03d9876 commit 9729521

4 files changed: +3 additions, -12 deletions

tensor2tensor/bin/t2t_trainer.py

Lines changed: 1 addition & 4 deletions
@@ -65,10 +65,8 @@
 flags.DEFINE_integer("intra_op_parallelism_threads", 0,
                      "Number of intra_op_parallelism_threads to use for CPU. "
                      "See TensorFlow config.proto for details.")
-# TODO(hinsu): Enable DistributionStrategy by default once performance gap
-# between DistributionStrategy and Parallelism is resolved.
 flags.DEFINE_bool(
-    "optionally_use_dist_strat", False,
+    "optionally_use_dist_strat", True,
     "Whether to use TensorFlow DistributionStrategy instead of explicitly "
     "replicating the model. DistributionStrategy is used only if the "
     "model replication configuration is supported by the DistributionStrategy.")

@@ -239,7 +237,6 @@ def create_run_config(hp, output_dir=None):
       keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
       num_gpus=FLAGS.worker_gpu,
       gpu_order=FLAGS.gpu_order,
-      shard_to_cpu=FLAGS.locally_shard_to_cpu,
       num_async_replicas=FLAGS.worker_replicas,
       gpu_mem_fraction=FLAGS.worker_gpu_memory_fraction,
       enable_graph_rewriter=FLAGS.enable_graph_rewriter,
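
The only behavioral change here is the default of optionally_use_dist_strat flipping from False to True, now that the TODO about the DistributionStrategy performance gap is resolved. A minimal sketch (assuming absl.flags, which the tf.flags module used by t2t_trainer.py wraps; not code from this commit) of how the new default interacts with command-line overrides:

from absl import app, flags

FLAGS = flags.FLAGS

# Same shape of definition as t2t_trainer.py after this commit: the second
# argument is the default, so DistributionStrategy is now opted into unless
# the user overrides it.
flags.DEFINE_bool(
    "optionally_use_dist_strat", True,
    "Whether to use TensorFlow DistributionStrategy instead of explicitly "
    "replicating the model.")

def main(_):
  # Passing --optionally_use_dist_strat=False on the command line restores
  # the pre-commit behavior of explicit model replication.
  print("optionally_use_dist_strat =", FLAGS.optionally_use_dist_strat)

if __name__ == "__main__":
  app.run(main)

Users who relied on the old default can keep the previous behavior by passing --optionally_use_dist_strat=False to the trainer.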

tensor2tensor/utils/devices.py

Lines changed: 1 addition & 2 deletions
@@ -70,7 +70,6 @@ def data_parallelism(daisy_chain_variables=True,
                      worker_replicas=1,
                      worker_id=0,
                      gpu_order="",
-                     locally_shard_to_cpu=False,
                      worker_job="/job:localhost",
                      no_data_parallelism=False):
   """See data_parallelism_from_flags."""

@@ -141,7 +140,7 @@ def _replica_device_setter(worker_device):
         "Schedule=%s. Assuming that training is running on a single machine.",
         schedule)
     datashard_devices = ["gpu:%d" % d for d in _gpu_order(worker_gpu)]
-    if locally_shard_to_cpu or worker_gpu < 1:
+    if worker_gpu < 1:
       datashard_devices += ["cpu:0"]
     caching_devices = None
   elif sync and ps_replicas > 0:
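
With locally_shard_to_cpu gone, data_parallelism() adds cpu:0 to the shard devices only when no GPUs are requested. A simplified standalone sketch of that selection logic (hypothetical helper, not part of devices.py; the real code orders GPUs via _gpu_order and the gpu_order flag rather than a plain range):

def pick_datashard_devices(worker_gpu):
  # GPUs first; simplified to sequential order instead of _gpu_order().
  devices = ["gpu:%d" % d for d in range(worker_gpu)]
  if worker_gpu < 1:
    # CPU is now only a fallback for GPU-less runs, no longer a debugging
    # option toggled by locally_shard_to_cpu.
    devices.append("cpu:0")
  return devices

assert pick_datashard_devices(0) == ["cpu:0"]
assert pick_datashard_devices(2) == ["gpu:0", "gpu:1"]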

tensor2tensor/utils/flags.py

Lines changed: 0 additions & 3 deletions
@@ -92,9 +92,6 @@
 flags.DEFINE_integer("eval_throttle_seconds", 600,
                      "Do not re-evaluate unless the last evaluation was started"
                      " at least this many seconds ago.")
-flags.DEFINE_bool("locally_shard_to_cpu", False,
-                  "Use CPU as a sharding device running locally. This allows "
-                  "to test sharded model construction on a machine with 1 GPU.")
 flags.DEFINE_bool("sync", False, "Sync compute on PS.")
 flags.DEFINE_string("worker_job", "/job:localhost", "name of worker job")
 flags.DEFINE_integer("worker_gpu", 1, "How many GPUs to use.")

tensor2tensor/utils/trainer_lib.py

Lines changed: 1 addition & 3 deletions
@@ -145,7 +145,6 @@ def create_run_config(model_name,
                       keep_checkpoint_every_n_hours=10000,
                       num_gpus=1,
                       gpu_order="",
-                      shard_to_cpu=False,
                       num_async_replicas=1,
                       enable_graph_rewriter=False,
                       gpu_mem_fraction=0.95,

@@ -239,7 +238,7 @@ def create_run_config(model_name,
       optionally_use_dist_strat and
       t2t_model.T2TModel.has_symmetric_shards(model_name) and
       not no_data_parallelism and ps_replicas == 0 and ps_gpu == 0 and
-      num_async_replicas == 1 and not shard_to_cpu)
+      num_async_replicas == 1)

   if use_distribution_strategy:
     tf.logging.info(

@@ -262,7 +261,6 @@ def create_run_config(model_name,
         worker_replicas=num_async_replicas,
         worker_id=worker_id,
         gpu_order=gpu_order,
-        locally_shard_to_cpu=shard_to_cpu,
         worker_job=worker_job,
         no_data_parallelism=no_data_parallelism)

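
The shard_to_cpu term is dropped from the condition that decides between DistributionStrategy and explicit replication. A simplified sketch of that decision (hypothetical standalone function mirroring the boolean in create_run_config after this commit, not the actual implementation):

def should_use_distribution_strategy(optionally_use_dist_strat,
                                     has_symmetric_shards,
                                     no_data_parallelism,
                                     ps_replicas, ps_gpu,
                                     num_async_replicas):
  # DistributionStrategy is used only for single-machine, non-async runs of
  # models whose replication configuration it supports; the shard_to_cpu
  # escape hatch no longer factors in.
  return (optionally_use_dist_strat and
          has_symmetric_shards and
          not no_data_parallelism and
          ps_replicas == 0 and ps_gpu == 0 and
          num_async_replicas == 1)

# With the flag's new default (True), a plain single-worker run of a
# symmetric-shard model now picks DistributionStrategy automatically.
assert should_use_distribution_strategy(True, True, False, 0, 0, 1)
assert not should_use_distribution_strategy(True, True, False, 2, 0, 1)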
