Skip to content

Commit

Permalink
Fix taskcluster train scripts (#895)
Browse files Browse the repository at this point in the history
  • Loading branch information
eu9ene authored Oct 22, 2024
1 parent b0b5f25 commit 9956ef2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
4 changes: 3 additions & 1 deletion taskcluster/scripts/pipeline/train_taskcluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@

ARTIFACTS_URL = "{root_url}/api/queue/v1/task/{task_id}/runs/{run_id}/artifacts"
ARTIFACT_URL = "{root_url}/api/queue/v1/task/{task_id}/runs/{run_id}/artifacts/{artifact_name}"
# TODO: consolidate everything in train.py or at least do not rely on the argument names and the number of them in the Taskcluster part
# TODO: https://github.com/mozilla/firefox-translations-training/issues/607
# The argument number where pretrained model mode is expected.
# This is 1-indexed, not 0-indexed, so it should line up with the argument
# number this is fetched in in train-taskcluster.sh
PRETRAINED_MODEL_MODE_ARG_NUMBER = 12
PRETRAINED_MODEL_MODE_ARG_NUMBER = 13
# Nothing special about 17...just a number plucked out of thin air that
# should be distinct enough to retry on.
DOWNLOAD_ERROR_EXIT_CODE = 17
Expand Down
28 changes: 16 additions & 12 deletions tests/test_train_taskcluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
pytest.param(
[
"model_type",
"type",
"training_type",
"src",
"trg",
"train_set_prefix",
Expand All @@ -35,14 +35,15 @@
"best_model_metric",
"alignments",
"seed",
"mode",
"teacher_mode",
"student_model",
],
id="required_only",
),
pytest.param(
[
"model_type",
"type",
"training_type",
"src",
"trg",
"train_set_prefix",
Expand All @@ -51,7 +52,8 @@
"best_model_metric",
"alignments",
"seed",
"mode",
"teacher_mode",
"student_model",
"pretrained_model_mode",
"pretrained_model_type",
],
Expand All @@ -60,7 +62,7 @@
pytest.param(
[
"model_type",
"type",
"training_type",
"src",
"trg",
"train_set_prefix",
Expand All @@ -69,7 +71,8 @@
"best_model_metric",
"alignments",
"seed",
"mode",
"teacher_mode",
"student_model",
"pretrained_model_mode",
"pretrained_model_type",
"--foo",
Expand Down Expand Up @@ -281,17 +284,18 @@ def fake_get(url, *args, **kwargs):
model_dir = DataDir("test_train_taskcluster").path
train_taskcluster.main(
[
"model-type",
"training-type",
"model_type",
"training_type",
"src",
"trg",
"train-set-prefix",
"valid-set-prefix",
model_dir,
"best-model-metric",
"alignents",
"alignments",
"seed",
"mode",
"teacher_mode",
"student_model",
orig_pretrained_model_mode,
]
)
Expand Down Expand Up @@ -346,10 +350,10 @@ def fake_get(url, *args, **kwargs):
assert tt_mock["requests"].get.call_args_list == calls

assert tt_mock["subprocess"].run.call_count == 1
# pretrained model mode is the 12th arg to the training script, but subprocess
# pretrained model mode is the 13th arg to the training script, but subprocess
# is also given the script name - so we look for the expected pretrained model mode
# in the 13th arg of the subprocess.run call
assert (
tt_mock["subprocess"].run.call_args_list[0][0][0][12]
tt_mock["subprocess"].run.call_args_list[0][0][0][13]
== expected_pretrained_model_mode
)

0 comments on commit 9956ef2

Please sign in to comment.