@@ -89,6 +89,10 @@ def get_arguments():
89
89
default = "lcrrot" ,
90
90
)
91
91
92
+ parser .add_argument (
93
+ "--job-id" , "-jid" , type = str , help = "ID of the job being submitted"
94
+ )
95
+
92
96
parser .add_argument (
93
97
"--job-dir" ,
94
98
help = "GCS location to write checkpoints to and export models" ,
@@ -195,7 +199,7 @@ def write_gcloud_config(args):
195
199
args_dict = vars (args )
196
200
args_list = []
197
201
for (key , value ) in args_dict .items ():
198
- if not value :
202
+ if not value or key == "job_id" :
199
203
continue
200
204
if isinstance (value , list ):
201
205
value = " " .join (map (str , value ))
@@ -206,7 +210,12 @@ def write_gcloud_config(args):
206
210
"jobId" : "my_job" ,
207
211
"labels" : {"type" : "dev" , "owner" : "sean" },
208
212
"trainingInput" : {
209
- "scaleTier" : "BASIC" ,
213
+ "scaleTier" : "CUSTOM" ,
214
+ "masterType" : "standard_gpu" ,
215
+ "workerType" : "standard_gpu" ,
216
+ "parameterServerType" : "standard_gpu" ,
217
+ "workerCount" : 4 ,
218
+ "parameterServerCount" : 3 ,
210
219
"pythonVersion" : "3.5" ,
211
220
"runtimeVersion" : "1.10" ,
212
221
"region" : "europe-west1" ,
@@ -218,16 +227,16 @@ def write_gcloud_config(args):
218
227
dump (gcloud_config , config_file , indent = 4 )
219
228
220
229
221
- def write_gcloud_cmd_script ():
230
+ def write_gcloud_cmd_script (args ):
222
231
gcloud_cmd = """gcloud ml-engine jobs submit training {job_name} \\
223
232
--job-dir={job_dir} \\
224
233
--module-name={module_name} \\
225
234
--staging-bucket={staging_bucket} \\
226
235
--packages={package_name} \\
227
236
--config={config_path} \\
228
237
--stream-logs""" .format (
229
- job_name = "testing_job_script_3" ,
230
- job_dir = "gs://tsaplay-bucket/testing_job_script" ,
238
+ job_name = args . job_id ,
239
+ job_dir = "gs://tsaplay-bucket/{}" . format ( args . job_id ) ,
231
240
module_name = "tsaplay.task" ,
232
241
staging_bucket = "gs://tsaplay-bucket/" ,
233
242
package_name = abspath (
@@ -257,7 +266,11 @@ def write_gcloud_cmd_script():
257
266
cprnt (bow = "Copied to clipboard!" )
258
267
259
268
260
- if __name__ == "__main__" :
261
- prepare_assets (get_arguments () )
269
+ def main ( args ) :
270
+ prepare_assets (args )
262
271
sandbox .run_setup ("setup.py" , ["sdist" ])
263
- write_gcloud_cmd_script ()
272
+ write_gcloud_cmd_script (args )
273
+
274
+
275
+ if __name__ == "__main__" :
276
+ main (get_arguments ())
0 commit comments