Skip to content

Commit

Permalink
fix: don't run evaluate tasks on pretrained models (#781)
Browse files Browse the repository at this point in the history
This is accomplished with a new transform used by the `evaluate` tasks, which avoids yielding any tasks whose `stage` matches a `pretrained-models` stage.
  • Loading branch information
bhearsum authored Sep 4, 2024
1 parent 92c2b45 commit fe815e1
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 0 deletions.
1 change: 1 addition & 0 deletions taskcluster/kinds/evaluate-teacher-ensemble/kind.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
loader: taskgraph.loader.transform:loader

transforms:
- translations_taskgraph.transforms.training_continuation:evaluate_stage
- translations_taskgraph.transforms.from_datasets:per_dataset
- translations_taskgraph.transforms.worker_selection
- taskgraph.transforms.from_deps
Expand Down
1 change: 1 addition & 0 deletions taskcluster/kinds/evaluate/kind.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
loader: taskgraph.loader.transform:loader

transforms:
- translations_taskgraph.transforms.training_continuation:evaluate_stage
- translations_taskgraph.transforms.from_datasets:per_dataset
- translations_taskgraph.transforms.worker_selection
- taskgraph.transforms.task_context
Expand Down
24 changes: 24 additions & 0 deletions taskcluster/test/test_training_continuation_backwards.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,19 @@
},
}

MOCK_REQUESTS = [
{
"method": "POST",
"url": "https://firefox-ci-tc.services.mozilla.com/api/index/v1/tasks/indexes",
"responses": [{"json": {"tasks": []}}],
},
{
"method": "POST",
"url": "https://firefox-ci-tc.services.mozilla.com/api/queue/v1/tasks/status",
"responses": [{"json": {"statuses": []}}],
},
]


def test_artifact_mounts(full_task_graph: TaskGraph):
task = [t for t in full_task_graph.tasks.values() if t.label == "train-backwards-ru-en"][0]
Expand All @@ -26,3 +39,14 @@ def test_artifact_mounts(full_task_graph: TaskGraph):
assert mounted_files["./artifacts/model.npz"]["content"] == {
"url": "https://storage.googleapis.com/releng-translations-dev/models/ru-en/better-teacher/student/model.npz",
}


def test_no_eval_tasks(optimized_task_graph: TaskGraph):
"""Ensure evaluate tasks for train-backwards aren't targeted.
See https://github.com/mozilla/firefox-translations-training/issues/628"""
eval_tasks = [
task.label
for task in optimized_task_graph.tasks.values()
if task.label.startswith("evaluate-backward")
]
assert len(eval_tasks) == 0
52 changes: 52 additions & 0 deletions taskcluster/test/test_training_continuation_teacher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from copy import deepcopy

from taskgraph.taskgraph import TaskGraph

from translations_taskgraph.parameters import get_defaults

PARAMS = deepcopy(get_defaults(None))
PARAMS["target_tasks_method"] = "train-target-tasks"
PARAMS["training_config"]["experiment"]["pretrained-models"] = {
"train-teacher": {
"mode": "use",
"type": "default",
"urls": [
"https://storage.googleapis.com/releng-translations-dev/models/ru-en/better-teacher/student"
],
},
}

MOCK_REQUESTS = [
{
"method": "POST",
"url": "https://firefox-ci-tc.services.mozilla.com/api/index/v1/tasks/indexes",
"responses": [{"json": {"tasks": []}}],
},
{
"method": "POST",
"url": "https://firefox-ci-tc.services.mozilla.com/api/queue/v1/tasks/status",
"responses": [{"json": {"statuses": []}}],
},
]


def test_artifact_mounts(full_task_graph: TaskGraph):
task = [t for t in full_task_graph.tasks.values() if t.label == "train-teacher-ru-en-1"][0]
# No need to bother looking for _all_ files (we'd just duplicate
# the full list if we did that...), but we verify that one file
# is well formed.
mounted_files = {m["file"]: m for m in task.task["payload"]["mounts"] if "file" in m}
assert mounted_files["./artifacts/model.npz"]["content"] == {
"url": "https://storage.googleapis.com/releng-translations-dev/models/ru-en/better-teacher/student/model.npz",
}


def test_no_eval_tasks(optimized_task_graph: TaskGraph):
"""Ensure evaluate tasks for train-teacher aren't targeted.
See https://github.com/mozilla/firefox-translations-training/issues/628"""
eval_tasks = [
task.label
for task in optimized_task_graph.tasks.values()
if task.label.startswith("evaluate-teacher")
]
assert len(eval_tasks) == 0
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,24 @@ def add_pretrained_model_mounts(config, jobs):
}

yield job


evaluate_stage = TransformSequence()


@evaluate_stage.add
def skip_for_pretrained_models(config, jobs):
# Find the types of pretrained models that are being used. This makes
# it easier to filter them out in the loop below.
pretrained_models = [
pretrained.split("-")[-1].replace("backwards", "backward")
for pretrained in config.params["training_config"]["experiment"]
.get("pretrained-models", {})
.keys()
]

for job in jobs:
if any([pretrained in job["attributes"]["stage"] for pretrained in pretrained_models]):
continue

yield job

0 comments on commit fe815e1

Please sign in to comment.