
from tensor2tensor import problems as problems_lib  # pylint: disable=unused-import
from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.envs import env_problem_utils
from tensor2tensor.utils import registry
from tensor2tensor.utils import usr_dir

# Importing here to prevent pylint from ungrouped-imports warning.
import tensorflow as tf  # pylint: disable=g-import-not-at-top

-
flags = tf.flags
FLAGS = flags.FLAGS

6565 "The name of the problem to generate data for." )
6666flags .DEFINE_string ("exclude_problems" , "" ,
6767 "Comma-separates list of problems to exclude." )
-flags.DEFINE_integer("num_shards", 0, "How many shards to use. Ignored for "
-                     "registered Problems.")
+flags.DEFINE_integer(
+    "num_shards", 0, "How many shards to use. Ignored for "
+    "registered Problems.")
flags.DEFINE_integer("max_cases", 0,
                     "Maximum number of cases to generate (unbounded if 0).")
+flags.DEFINE_integer(
+    "env_problem_max_env_steps", 0,
+    "Maximum number of steps to take for environment-based problems. "
+    "Actions are chosen randomly")
+flags.DEFINE_integer(
+    "env_problem_batch_size", 0,
+    "Number of environments to simulate for environment-based problems.")
flags.DEFINE_bool("only_list", False,
                  "If true, we only list the problems that will be generated.")
flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
flags.DEFINE_integer(
    "num_concurrent_processes", None,
    "Applies only to problems for which multiprocess_generate=True.")
-flags.DEFINE_string("t2t_usr_dir", "",
-                    "Path to a Python module that will be imported. The "
-                    "__init__.py file should include the necessary imports. "
-                    "The imported files should contain registrations, "
-                    "e.g. @registry.register_problem calls, that will then be "
-                    "available to t2t-datagen.")
+flags.DEFINE_string(
+    "t2t_usr_dir", "", "Path to a Python module that will be imported. The "
+    "__init__.py file should include the necessary imports. "
+    "The imported files should contain registrations, "
+    "e.g. @registry.register_problem calls, that will then be "
+    "available to t2t-datagen.")

# Mapping from problems that we can generate data for to their generators.
# pylint: disable=g-long-lambda
_SUPPORTED_PROBLEM_GENERATORS = {
91- "algorithmic_algebra_inverse" : (
92- lambda : algorithmic_math .algebra_inverse (26 , 0 , 2 , 100000 ),
93- lambda : algorithmic_math .algebra_inverse (26 , 3 , 3 , 10000 ),
94- lambda : None ), # test set
95- "parsing_english_ptb8k" : (
96- lambda : wsj_parsing .parsing_token_generator (
99+ "algorithmic_algebra_inverse" :
100+ ( lambda : algorithmic_math .algebra_inverse (26 , 0 , 2 , 100000 ),
101+ lambda : algorithmic_math .algebra_inverse (26 , 3 , 3 , 10000 ),
102+ lambda : None ), # test set
103+ "parsing_english_ptb8k" :
104+ ( lambda : wsj_parsing .parsing_token_generator (
97105 FLAGS .data_dir , FLAGS .tmp_dir , True , 2 ** 13 , 2 ** 9 ),
98- lambda : wsj_parsing .parsing_token_generator (
99- FLAGS .data_dir , FLAGS .tmp_dir , False , 2 ** 13 , 2 ** 9 ),
100- lambda : None ), # test set
101- "parsing_english_ptb16k" : (
102- lambda : wsj_parsing .parsing_token_generator (
106+ lambda : wsj_parsing .parsing_token_generator (
107+ FLAGS .data_dir , FLAGS .tmp_dir , False , 2 ** 13 , 2 ** 9 ),
108+ lambda : None ), # test set
109+ "parsing_english_ptb16k" :
110+ ( lambda : wsj_parsing .parsing_token_generator (
103111 FLAGS .data_dir , FLAGS .tmp_dir , True , 2 ** 14 , 2 ** 9 ),
-        lambda: wsj_parsing.parsing_token_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 2**14, 2**9),
-        lambda: None),  # test set
-    "inference_snli32k": (
-        lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
-        lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
-        lambda: None),  # test set
-    "audio_timit_characters_test": (
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, True, 1718),
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 626),
-        lambda: None),  # test set
-    "audio_timit_tokens_8k_test": (
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, True, 1718,
-            vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13),
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 626,
-            vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13),
-        lambda: None),  # test set
-    "audio_timit_tokens_32k_test": (
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, True, 1718,
-            vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15),
-        lambda: audio.timit_generator(
-            FLAGS.data_dir, FLAGS.tmp_dir, False, 626,
-            vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15),
-        lambda: None),  # test set
+         lambda: wsj_parsing.parsing_token_generator(
+             FLAGS.data_dir, FLAGS.tmp_dir, False, 2**14, 2**9),
+         lambda: None),  # test set
+    "inference_snli32k":
+        (lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15),
+         lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15),
+         lambda: None),  # test set
+    "audio_timit_characters_test": (lambda: audio.timit_generator(
+        FLAGS.data_dir, FLAGS.tmp_dir, True, 1718
+    ), lambda: audio.timit_generator(FLAGS.data_dir, FLAGS.tmp_dir, False, 626),
+                                    lambda: None),  # test set
+    "audio_timit_tokens_8k_test": (lambda: audio.timit_generator(
+        FLAGS.data_dir,
+        FLAGS.tmp_dir,
+        True,
+        1718,
+        vocab_filename="vocab.endefr.%d" % 2**13,
+        vocab_size=2**13), lambda: audio.timit_generator(
+            FLAGS.data_dir,
+            FLAGS.tmp_dir,
+            False,
+            626,
+            vocab_filename="vocab.endefr.%d" % 2**13,
+            vocab_size=2**13), lambda: None),  # test set
+    "audio_timit_tokens_32k_test": (lambda: audio.timit_generator(
+        FLAGS.data_dir,
+        FLAGS.tmp_dir,
+        True,
+        1718,
+        vocab_filename="vocab.endefr.%d" % 2**15,
+        vocab_size=2**15), lambda: audio.timit_generator(
+            FLAGS.data_dir,
+            FLAGS.tmp_dir,
+            False,
+            626,
+            vocab_filename="vocab.endefr.%d" % 2**15,
+            vocab_size=2**15), lambda: None),  # test set
}

# pylint: enable=g-long-lambda
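Each value in _SUPPORTED_PROBLEM_GENERATORS is a (train, dev, test) triple of zero-argument callables that build example generators, with the test slot returning None for these problems. The body of generate_data_for_problem is not shown in this diff; the sketch below only illustrates how such a triple is typically consumed, assuming the generator_utils shard-naming and file-writing helpers (train_data_filenames, dev_data_filenames, generate_files) behave as they do elsewhere in tensor2tensor. The helper name _generate_from_triple is invented for the sketch.

# Illustrative sketch only, not the generate_data_for_problem body from this
# commit. Helper names are assumed from tensor2tensor.data_generators.
def _generate_from_triple(problem_name):
  training_gen, dev_gen, _ = _SUPPORTED_PROBLEM_GENERATORS[problem_name]
  num_shards = FLAGS.num_shards or 10
  train_files = generator_utils.train_data_filenames(
      problem_name + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir,
      num_shards)
  dev_files = generator_utils.dev_data_filenames(
      problem_name + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, 1)
  # The dict stores lambdas so that FLAGS are fully parsed before any
  # generator reads FLAGS.data_dir or FLAGS.tmp_dir.
  generator_utils.generate_files(training_gen(), train_files, FLAGS.max_cases)
  generator_utils.generate_files(dev_gen(), dev_files)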
@@ -147,7 +163,8 @@ def main(_):

  # Calculate the list of problems to generate.
  problems = sorted(
-      list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_base_problems())
+      list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_base_problems() +
+      registry.list_env_problems())
  for exclude in FLAGS.exclude_problems.split(","):
    if exclude:
      problems = [p for p in problems if exclude not in p]
@@ -169,8 +186,9 @@ def main(_):

  if not problems:
    problems_str = "\n  * ".join(
-        sorted(list(_SUPPORTED_PROBLEM_GENERATORS) +
-               registry.list_base_problems()))
+        sorted(
+            list(_SUPPORTED_PROBLEM_GENERATORS) +
+            registry.list_base_problems() + registry.list_env_problems()))
    error_msg = ("You must specify one of the supported problems to "
                 "generate data for:\n  * " + problems_str + "\n")
    error_msg += ("TIMIT and parsing need data_sets specified with "
@@ -179,24 +197,28 @@ def main(_):

  if not FLAGS.data_dir:
    FLAGS.data_dir = tempfile.gettempdir()
-    tf.logging.warning("It is strongly recommended to specify --data_dir. "
-                       "Data will be written to default data_dir=%s.",
-                       FLAGS.data_dir)
+    tf.logging.warning(
+        "It is strongly recommended to specify --data_dir. "
+        "Data will be written to default data_dir=%s.", FLAGS.data_dir)
  FLAGS.data_dir = os.path.expanduser(FLAGS.data_dir)
  tf.gfile.MakeDirs(FLAGS.data_dir)

-  tf.logging.info("Generating problems:\n%s"
-                  % registry.display_list_by_prefix(problems,
-                                                    starting_spaces=4))
+  tf.logging.info("Generating problems:\n%s" %
+                  registry.display_list_by_prefix(problems, starting_spaces=4))
  if FLAGS.only_list:
    return
  for problem in problems:
    set_random_seed()

    if problem in _SUPPORTED_PROBLEM_GENERATORS:
      generate_data_for_problem(problem)
-    else:
+    elif problem in registry.list_base_problems():
      generate_data_for_registered_problem(problem)
+    elif problem in registry.list_env_problems():
+      generate_data_for_env_problem(problem)
+    else:
+      tf.logging.error("Problem %s is not a supported problem for datagen.",
+                       problem)


def generate_data_for_problem(problem):
@@ -235,6 +257,24 @@ def generate_data_in_process(arg):
  problem.generate_data(data_dir, tmp_dir, task_id)


+def generate_data_for_env_problem(problem_name):
+  """Generate data for `EnvProblem`s."""
+  assert FLAGS.env_problem_max_env_steps > 0, ("--env_problem_max_env_steps "
+                                               "should be greater than zero")
+  assert FLAGS.env_problem_batch_size > 0, ("--env_problem_batch_size should be"
+                                            " greater than zero")
+  problem = registry.env_problem(problem_name)
+  task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
+  data_dir = os.path.expanduser(FLAGS.data_dir)
+  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
+  # TODO(msaffar): Handle large values for env_problem_batch_size where we
+  # cannot create that many environments within the same process.
+  problem.initialize(batch_size=FLAGS.env_problem_batch_size)
+  env_problem_utils.play_env_problem_randomly(
+      problem, num_steps=FLAGS.env_problem_max_env_steps)
+  problem.generate_data(data_dir=data_dir, tmp_dir=tmp_dir, task_id=task_id)
+
+
def generate_data_for_registered_problem(problem_name):
  """Generate data for a registered problem."""
  tf.logging.info("Generating data for %s.", problem_name)
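The new generate_data_for_env_problem path above relies on env_problem_utils.play_env_problem_randomly to fill the environments with experience before generate_data writes it out. The snippet below is a rough sketch of that kind of random rollout, assuming a gym-like batched EnvProblem interface (reset, step, action_space, batch_size); it is not the actual tensor2tensor implementation, and _play_randomly_sketch is an invented name. Note that the amount of collected experience scales with both new flags, roughly env_problem_batch_size * env_problem_max_env_steps environment steps in total.

import numpy as np

# Rough sketch of a random rollout over a batched EnvProblem. Interface
# assumptions: reset(), step(actions), action_space, batch_size.
def _play_randomly_sketch(env_problem, num_steps):
  env_problem.reset()
  for _ in range(num_steps):
    # One independently sampled action per environment in the batch.
    actions = np.stack([
        env_problem.action_space.sample()
        for _ in range(env_problem.batch_size)
    ])
    env_problem.step(actions)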
@@ -260,6 +300,7 @@ def generate_data_for_registered_problem(problem_name):
  else:
    problem.generate_data(data_dir, tmp_dir, task_id)

+
if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()