This repository was archived by the owner on Jul 7, 2023. It is now read-only.
4 files changed: 3 additions, 12 deletions.
File 1 of 4:

@@ -65,10 +65,8 @@
 flags.DEFINE_integer("intra_op_parallelism_threads", 0,
                      "Number of intra_op_parallelism_threads to use for CPU. "
                      "See TensorFlow config.proto for details.")
-# TODO(hinsu): Enable DistributionStrategy by default once performance gap
-# between DistributionStrategy and Parallelism is resolved.
 flags.DEFINE_bool(
-    "optionally_use_dist_strat", False,
+    "optionally_use_dist_strat", True,
     "Whether to use TensorFlow DistributionStrategy instead of explicitly "
     "replicating the model. DistributionStrategy is used only if the "
     "model replication configuration is supported by the DistributionStrategy.")
@@ -239,7 +237,6 @@ def create_run_config(hp, output_dir=None):
       keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
       num_gpus=FLAGS.worker_gpu,
       gpu_order=FLAGS.gpu_order,
-      shard_to_cpu=FLAGS.locally_shard_to_cpu,
       num_async_replicas=FLAGS.worker_replicas,
       gpu_mem_fraction=FLAGS.worker_gpu_memory_fraction,
      enable_graph_rewriter=FLAGS.enable_graph_rewriter,
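With this change the old TODO is resolved and DistributionStrategy becomes the default path whenever the model qualifies; only the flag's default flips from False to True. A minimal sketch of how a caller could opt back into explicit replication, assuming the standard absl flag machinery behind these definitions (everything except the flag name "optionally_use_dist_strat" is illustrative):

# Hedged sketch: opt out of the new default. Only the flag name comes from
# the diff above; the script structure is illustrative.
from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_bool("optionally_use_dist_strat", True,
                  "Use DistributionStrategy when the model supports it.")

def main(argv):
  del argv  # Unused.
  if not FLAGS.optionally_use_dist_strat:
    print("Falling back to explicit model replication.")

if __name__ == "__main__":
  app.run(main)  # e.g. python trainer.py --optionally_use_dist_strat=false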
File 2 of 4:

@@ -70,7 +70,6 @@ def data_parallelism(daisy_chain_variables=True,
                      worker_replicas=1,
                      worker_id=0,
                      gpu_order="",
-                     locally_shard_to_cpu=False,
                      worker_job="/job:localhost",
                      no_data_parallelism=False):
   """See data_parallelism_from_flags."""
@@ -141,7 +140,7 @@ def _replica_device_setter(worker_device):
         "Schedule=%s. Assuming that training is running on a single machine.",
         schedule)
     datashard_devices = ["gpu:%d" % d for d in _gpu_order(worker_gpu)]
-    if locally_shard_to_cpu or worker_gpu < 1:
+    if worker_gpu < 1:
       datashard_devices += ["cpu:0"]
     caching_devices = None
   elif sync and ps_replicas > 0:
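The second hunk is the behavioral core of the change: the CPU now joins the shard devices only when no worker GPU is present, rather than via the removed locally_shard_to_cpu escape hatch. A self-contained sketch of that selection logic, with _gpu_order simplified and all names treated as illustrative stand-ins for the library code:

# Sketch of the post-change shard-device selection; mirrors the diff, not the library.
def _gpu_order(num_gpus, gpu_order=""):
  # Honor an explicit ordering like "2 0 1"; otherwise use 0..num_gpus-1.
  if gpu_order:
    return [int(s) for s in gpu_order.split(" ")]
  return list(range(num_gpus))

def select_datashard_devices(worker_gpu, gpu_order=""):
  datashard_devices = ["gpu:%d" % d for d in _gpu_order(worker_gpu, gpu_order)]
  if worker_gpu < 1:
    # No GPUs at all: shard on the CPU instead.
    datashard_devices += ["cpu:0"]
  return datashard_devices

print(select_datashard_devices(0))  # ['cpu:0']
print(select_datashard_devices(2))  # ['gpu:0', 'gpu:1']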
File 3 of 4:

@@ -92,9 +92,6 @@
 flags.DEFINE_integer("eval_throttle_seconds", 600,
                      "Do not re-evaluate unless the last evaluation was started"
                      " at least this many seconds ago.")
-flags.DEFINE_bool("locally_shard_to_cpu", False,
-                  "Use CPU as a sharding device running locally. This allows "
-                  "to test sharded model construction on a machine with 1 GPU.")
 flags.DEFINE_bool("sync", False, "Sync compute on PS.")
 flags.DEFINE_string("worker_job", "/job:localhost", "name of worker job")
 flags.DEFINE_integer("worker_gpu", 1, "How many GPUs to use.")
File 4 of 4:

@@ -145,7 +145,6 @@ def create_run_config(model_name,
                       keep_checkpoint_every_n_hours=10000,
                       num_gpus=1,
                       gpu_order="",
-                      shard_to_cpu=False,
                       num_async_replicas=1,
                       enable_graph_rewriter=False,
                       gpu_mem_fraction=0.95,
@@ -239,7 +238,7 @@ def create_run_config(model_name,
       optionally_use_dist_strat and
       t2t_model.T2TModel.has_symmetric_shards(model_name) and
       not no_data_parallelism and ps_replicas == 0 and ps_gpu == 0 and
-      num_async_replicas == 1 and not shard_to_cpu)
+      num_async_replicas == 1)

   if use_distribution_strategy:
     tf.logging.info(
@@ -262,7 +261,6 @@ def create_run_config(model_name,
       worker_replicas=num_async_replicas,
       worker_id=worker_id,
       gpu_order=gpu_order,
-      locally_shard_to_cpu=shard_to_cpu,
       worker_job=worker_job,
       no_data_parallelism=no_data_parallelism)
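The middle hunk above tightens the eligibility test: with shard_to_cpu gone, DistributionStrategy is chosen purely from the replication configuration. A standalone sketch of that predicate, with the has_symmetric_shards check replaced by a boolean argument since the real check lives on T2TModel (all names here are illustrative):

# Sketch of the post-change eligibility check; mirrors the diff, not the API.
def should_use_distribution_strategy(optionally_use_dist_strat,
                                     model_has_symmetric_shards,
                                     no_data_parallelism,
                                     ps_replicas, ps_gpu,
                                     num_async_replicas):
  return (optionally_use_dist_strat and
          model_has_symmetric_shards and
          not no_data_parallelism and
          ps_replicas == 0 and ps_gpu == 0 and
          num_async_replicas == 1)

# With the new default, a single-machine model with symmetric shards qualifies:
print(should_use_distribution_strategy(True, True, False, 0, 0, 1))  # True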