Publish Marian/OpusTrainer configuration YAMLs and dataset statistics (#720)

* Publish Marian, OpusTrainer configs and dataset statistics

* Update tests

* Fixes

* Fix tests in CI context

* Nit

* Store extra config files as new keys of the main config

* Plot datasets in a custom chart

* Fixes

* Support extra-args for offline publication

* Fix tests

* Suggestions

* TRASHME Test publication from CI

* Revert "TRASHME Test publication from CI"

This reverts commit 2da4a9a.

* Suggestion

* Trigger CI

* Fix training and model key detection

* TRASHME: Trigger publication from CI

* Revert "TRASHME: Trigger publication from CI"

This reverts commit ad4a3b7.

---------

Co-authored-by: Evgeny Pavlov <[email protected]>
Co-authored-by: Evgeny Pavlov <[email protected]>
3 people authored Sep 3, 2024
1 parent f7247a6 commit 70eff55
Showing 4 changed files with 216 additions and 40 deletions.
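As a rough sketch of the shape this change gives the run configuration published to Weights & Biases (key names are taken from the diff below; the values are placeholders, not real data), Marian's parsed YAML is now nested under a "marian" key, with the extra configuration files and dataset statistics stored as sibling keys:

# Rough sketch of the published run config after this change; values are placeholders.
run_config = {
    "marian": {"after": "2e"},      # Marian's own YAML dump, previously stored at the top level
    "arguments": {"seed": "1111"},  # Marian command line arguments parsed from the training log
    "model": None,                  # configs/model/*.yml (Taskcluster context only)
    "training": None,               # configs/training/*.yml (Taskcluster context only)
    "opustrainer": None,            # OpusTrainer YAML (Taskcluster context only)
    "datasets": None,               # line counts per OpusTrainer dataset (Taskcluster context only)
}

# This is why the test assertions below read
#   c.kwargs["config"].get("marian", {}).get("after")
# instead of c.kwargs["config"].get("after").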
65 changes: 54 additions & 11 deletions tests/test_tracking_cli.py
@@ -20,12 +20,17 @@ def tmp_dir():
return Path(DataDir("test_tracking").path)


@pytest.fixture(autouse=True)
@pytest.fixture
def disable_wandb(tmp_dir):
"""Prevent publication on W&B"""
environ = os.environ.copy()
os.environ["WANDB_API_KEY"] = "fake"
os.environ["WANDB_MODE"] = "offline"
os.environ["WANDB_DIR"] = str(tmp_dir / "wandb")
# Remove task ID to prevent publishing context data (training configuration, dataset)
os.environ.pop("TASK_ID", None)
yield
os.environ.update(environ)


@pytest.fixture
@@ -53,7 +58,7 @@ def samples_dir():
),
)
@patch("translations_parser.publishers.wandb")
def test_taskcluster(wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir):
def test_taskcluster(wandb_mock, getargs_mock, disable_wandb, caplog, samples_dir, tmp_dir):
caplog.set_level(logging.INFO)
wandb_dir = tmp_dir / "wandb"
wandb_dir.mkdir(parents=True)
@@ -63,6 +68,11 @@ def test_taskcluster(wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir):
assert [(level, message) for _module, level, message in caplog.record_tuples] == [
(logging.INFO, "Reading logs stream."),
(logging.INFO, "Detected Marian version 1.10"),
(logging.INFO, "Reading Marian command line arguments."),
(
logging.INFO,
"Extra configuration files can only be retrieved in Taskcluster context, skipping.",
),
(logging.INFO, "Successfully parsed 1528 lines"),
(logging.INFO, "Found 102 training entries"),
(logging.INFO, "Found 34 validation entries"),
@@ -74,7 +84,7 @@ def test_taskcluster(wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir):
c.kwargs["group"],
c.kwargs["name"],
c.kwargs["id"],
c.kwargs["config"].get("after"),
c.kwargs["config"].get("marian", {}).get("after"),
)
for c in wandb_mock.init.call_args_list
] == [
@@ -92,7 +102,9 @@ def test_taskcluster(wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir):
return_value=argparse.Namespace(directory=Path(__file__).parent / "data" / "experiments_1_10"),
)
@patch("translations_parser.publishers.wandb")
def test_experiments_marian_1_10(wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir):
def test_experiments_marian_1_10(
wandb_mock, getargs_mock, disable_wandb, caplog, samples_dir, tmp_dir
):
caplog.set_level(logging.INFO)
wandb_dir = tmp_dir / "wandb"
wandb_dir.mkdir(parents=True)
@@ -112,6 +124,11 @@ def test_experiments_marian_1_10(wandb_mock, getargs_mock, caplog, samples_dir,
),
(logging.INFO, "Reading logs stream."),
(logging.INFO, "Detected Marian version 1.10"),
(logging.INFO, "Reading Marian command line arguments."),
(
logging.INFO,
"Extra configuration files can only be retrieved in Taskcluster context, skipping.",
),
(logging.INFO, "Successfully parsed 1878 lines"),
(logging.INFO, "Found 550 training entries"),
(logging.INFO, "Found 108 validation entries"),
@@ -155,7 +172,7 @@ def test_experiments_marian_1_10(wandb_mock, getargs_mock, caplog, samples_dir,
c.kwargs["group"],
c.kwargs["name"],
c.kwargs["id"],
c.kwargs["config"].get("after"),
c.kwargs["config"].get("marian", {}).get("after"),
)
for c in wandb_mock.init.call_args_list
] == [
@@ -211,7 +228,9 @@ def test_experiments_marian_1_10(wandb_mock, getargs_mock, caplog, samples_dir,
return_value=argparse.Namespace(directory=Path(__file__).parent / "data" / "experiments_1_12"),
)
@patch("translations_parser.publishers.wandb")
def test_experiments_marian_1_12(wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir):
def test_experiments_marian_1_12(
wandb_mock, getargs_mock, disable_wandb, caplog, samples_dir, tmp_dir
):
caplog.set_level(logging.INFO)
wandb_dir = tmp_dir / "wandb"
wandb_dir.mkdir(parents=True)
@@ -224,6 +243,12 @@ def test_experiments_marian_1_12(wandb_mock, getargs_mock, caplog, samples_dir,
assert set([(level, message) for _module, level, message in caplog.record_tuples]) == set(
[
(logging.INFO, "Reading 2 train.log data"),
(logging.INFO, "Detected Marian version 1.12"),
(logging.INFO, "Reading Marian command line arguments."),
(
logging.INFO,
"Extra configuration files can only be retrieved in Taskcluster context, skipping.",
),
(
logging.INFO,
f"Parsing folder {samples_dir}/experiments_1_12/models/fi-en/opusprod/student",
@@ -247,7 +272,6 @@ def test_experiments_marian_1_12(wandb_mock, getargs_mock, caplog, samples_dir,
(logging.INFO, "Found 4 quantized metrics from speed folder"),
(logging.INFO, "Found 8 metrics from task logs"),
(logging.INFO, "Creating missing run quantized with associated metrics"),
(logging.INFO, "Detected Marian version 1.12"),
]
)

@@ -257,7 +281,7 @@ def test_experiments_marian_1_12(wandb_mock, getargs_mock, caplog, samples_dir,
c.kwargs["group"],
c.kwargs["name"],
c.kwargs["id"],
c.kwargs["config"].get("after"),
c.kwargs["config"].get("marian", {}).get("after"),
)
for c in wandb_mock.init.call_args_list
] == [
@@ -317,7 +341,7 @@ def test_experiments_marian_1_12(wandb_mock, getargs_mock, caplog, samples_dir,
)
@patch("translations_parser.publishers.wandb")
def test_taskcluster_wandb_initialization_failure(
wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir
wandb_mock, getargs_mock, disable_wandb, caplog, samples_dir, tmp_dir
):
"""
Ensures tracking continues despite W&B initialization failure
@@ -328,6 +352,11 @@ def test_taskcluster_wandb_initialization_failure(
assert [(level, message) for _module, level, message in caplog.record_tuples] == [
(logging.INFO, "Reading logs stream."),
(logging.INFO, "Detected Marian version 1.10"),
(logging.INFO, "Reading Marian command line arguments."),
(
logging.INFO,
"Extra configuration files can only be retrieved in Taskcluster context, skipping.",
),
(
logging.ERROR,
"WandB client could not be initialized: Invalid credentials. No data will be published.",
@@ -358,7 +387,9 @@ def test_taskcluster_wandb_initialization_failure(
),
)
@patch("translations_parser.publishers.wandb")
def test_taskcluster_wandb_log_failures(wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir):
def test_taskcluster_wandb_log_failures(
wandb_mock, getargs_mock, disable_wandb, caplog, samples_dir, tmp_dir
):
"""
Ensures tracking continues despite potential W&B data log failures
"""
@@ -372,6 +403,11 @@ def test_taskcluster_wandb_log_failures(wandb_mock, getargs_mock, caplog, sample
assert [(level, message) for _module, level, message in caplog.record_tuples] == [
(logging.INFO, "Reading logs stream."),
(logging.INFO, "Detected Marian version 1.10"),
(logging.INFO, "Reading Marian command line arguments."),
(
logging.INFO,
"Extra configuration files can only be retrieved in Taskcluster context, skipping.",
),
] + [
(logging.ERROR, "Error publishing training epoch using WandB: Unexpected failure"),
(logging.ERROR, "Error publishing training epoch using WandB: Unexpected failure"),
@@ -404,7 +440,9 @@ def test_taskcluster_wandb_log_failures(wandb_mock, getargs_mock, caplog, sample
),
)
@patch("translations_parser.publishers.wandb")
def test_taskcluster_wandb_disabled(wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir):
def test_taskcluster_wandb_disabled(
wandb_mock, getargs_mock, disable_wandb, caplog, samples_dir, tmp_dir
):
"""
Ensures tracking continues without Weights & Biases publication
"""
@@ -417,6 +455,11 @@ def test_taskcluster_wandb_disabled(wandb_mock, getargs_mock, caplog, samples_di
),
(logging.INFO, "Reading logs stream."),
(logging.INFO, "Detected Marian version 1.10"),
(logging.INFO, "Reading Marian command line arguments."),
(
logging.INFO,
"Extra configuration files can only be retrieved in Taskcluster context, skipping.",
),
(logging.INFO, "Successfully parsed 1528 lines"),
(logging.INFO, "Found 102 training entries"),
(logging.INFO, "Found 34 validation entries"),
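The test changes above hinge on the reworked disable_wandb fixture: it is no longer autouse, so every test requests it explicitly, and it now pops TASK_ID so the parser follows the offline path and logs the new "Extra configuration files can only be retrieved in Taskcluster context, skipping." message. A minimal hypothetical usage (test name invented) could look like this:

def test_offline_path(disable_wandb, caplog, samples_dir, tmp_dir):
    # TASK_ID is unset by the fixture, so get_extra_marian_config() keeps only
    # the parsed command line arguments and skips the extra YAML files.
    ...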
112 changes: 107 additions & 5 deletions tracking/translations_parser/parser.py
@@ -1,16 +1,20 @@
import logging
import os
import re
import shlex
import sys
from collections import defaultdict
from collections.abc import Iterable, Iterator, Sequence
from datetime import datetime
from itertools import tee
from pathlib import Path
from typing import Callable, DefaultDict, List

import yaml

from translations_parser.data import Metric, TrainingEpoch, TrainingLog, ValidationEpoch
from translations_parser.publishers import Publisher
from translations_parser.utils import get_lines_count

logger = logging.getLogger(__name__)

@@ -36,6 +40,10 @@
# Expected Marian versions for clean parsing
SUPPORTED_MARIAN_VERSIONS = [(1, 10), (1, 12)]

MARIAN_ARGS_REGEX = re.compile(r"command line:[\n ]+[\w\/-]+\/marian +(.*)")
# Last declared Marian command line argument (not part of the training extra arguments)
LAST_MARIAN_DECLARED_ARGUMENT = "seed"


class TrainingParser:
def __init__(
@@ -215,6 +223,97 @@ def _join(seq):

yield headers, text

def get_extra_marian_config(self) -> dict:
"""
Read extra configuration files (Marian, OpusTrainer, extra CLI arguments).
Publication outside of a Taskcluster context (offline mode) cannot access
the configuration files; in that case only the extra CLI arguments are set.
"""
extra_config = {
"arguments": None,
"model": None,
"training": None,
"datasets": None,
"opustrainer": None,
}

if (
self.description is None
or (match := MARIAN_ARGS_REGEX.search(self.description)) is None
):
logger.error(self.description)
logger.warning(
"Invalid Marian description, skipping Marian and OpusTrainer configuration detection."
)
return extra_config

logger.info("Reading Marian command line arguments.")
(arguments_str,) = match.groups()
# Build args from the command line input text
args = defaultdict(list)
key = None
for i in iter(shlex.split(arguments_str)):
if i.startswith("-"):
key = i.strip("-")
continue
args[key].append(i)

# Store arguments used to run Marian, flattening single values
def flatten(vals):
if not vals:
return ""
elif len(vals) == 1:
return vals[0]
return vals

extra_config["arguments"] = {k: flatten(v) for k, v in args.items()}

if os.environ.get("TASK_ID") is None:
logger.info(
"Extra configuration files can only be retrieved in Taskcluster context, skipping."
)
return extra_config

# Handle Marian model and training YAML configuration files (passed as --config or -c)
for path in args.get("config", args["c"]):
if path.startswith("configs/training"):
key = "training"
elif path.startswith("configs/model"):
key = "model"
else:
continue
try:
with open(path, "r") as f:
extra_config[key] = yaml.safe_load(f.read())
except Exception as e:
logger.warning(f"Impossible to parse Marian {key} config at {path}: {e}")

# Handle OpusTrainer configuration
(model_path,) = args.get("model", ("./model.npz",))
model_dir = Path(model_path).parent
train_conf_path = (model_dir / "config.opustrainer.yml").resolve()
if not train_conf_path.exists():
logger.warning(f"OpusTrainer configuration file does not exists at {train_conf_path}.")
else:
try:
with open(train_conf_path, "r") as f:
extra_config["opustrainer"] = yaml.safe_load(f.read())
except Exception as e:
logger.warning(f"Impossible to parse OpusTrainer config at {train_conf_path}: {e}")
else:
logger.info("Reading datasets statistics from OpusTrainer configuration.")
try:
dataset_conf = extra_config.get("opustrainer", {}).get("datasets", {})
extra_config["datasets"] = {
key: get_lines_count(path) for key, path in dataset_conf.items()
}
except Exception as e:
logger.warning(
f"OpusTrainer configuration could not be read at {train_conf_path}: {e}."
)

return extra_config

def parse_marian_context(self, logs_iter: Iterator[tuple[list[tuple[str]], str]]) -> None:
"""
Looks for Marian context in the first log lines.
@@ -231,6 +330,7 @@ def parse_marian_context(self, logs_iter: Iterator[tuple[list[tuple[str]], str]]
version = version.rstrip(";")
major, minor = map(int, version.lstrip("v").split(".")[:2])
self.version = f"{major}.{minor}"
logger.info(f"Detected Marian version {self.version}")
if (major, minor) not in SUPPORTED_MARIAN_VERSIONS:
versions = ", ".join(f"{major}.{minor}" for major, minor in SUPPORTED_MARIAN_VERSIONS)
logger.warning(
@@ -241,7 +341,8 @@ def parse_marian_context(self, logs_iter: Iterator[tuple[list[tuple[str]], str]]
logger.debug("Reading Marian run description.")
desc = []
for headers, text in logs_iter:
if ("marian",) not in headers:
# The Marian header section stops when the configuration is dumped
if ("config",) in headers:
break
desc.append(text)
self.description = " ".join(desc)
@@ -257,11 +358,12 @@ def parse_marian_context(self, logs_iter: Iterator[tuple[list[tuple[str]], str]]
config_yaml += f"{text}\n"
headers, text = next(logs_iter)
try:
self.config = yaml.safe_load(config_yaml)
self.config["marian"] = yaml.safe_load(config_yaml)
except Exception as e:
raise Exception(f"Invalid config section: {e}")
logger.error(f"Impossible to parse Marian config YAML: {e}")

logger.info(f"Detected Marian version {self.version}")
# Try to read required extra configuration files when running online from Taskcluster
self.config.update(self.get_extra_marian_config())

def parse_data(self, logs_iter: Iterator[tuple[list[tuple[str]], str]]) -> None:
"""
Expand Down Expand Up @@ -310,7 +412,7 @@ def parse(self) -> None:
# Once all data has been parsed, call the final publication API
for publisher in self.publishers:
try:
publisher.publish(self.output)
publisher.publish()
# Publish optional metrics
if self.metrics:
publisher.handle_metrics(self.metrics)
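To make the new argument parsing in get_extra_marian_config concrete, here is a small self-contained sketch of the same shlex/defaultdict/flatten logic applied to a made-up Marian command line (the command line itself is invented for illustration; real ones are extracted from the training log by MARIAN_ARGS_REGEX):

import shlex
from collections import defaultdict

# Invented example command line for illustration only.
arguments_str = (
    "--model ./model.npz "
    "-c configs/model/teacher.yml configs/training/teacher.train.yml "
    "--seed 1111 --after 2e"
)

args = defaultdict(list)
key = None
for token in shlex.split(arguments_str):
    if token.startswith("-"):
        key = token.strip("-")  # "--model" and "-c" both become bare keys
        continue
    args[key].append(token)

def flatten(vals):
    # Single-value options become strings, multi-value options stay lists,
    # bare flags become empty strings.
    if not vals:
        return ""
    if len(vals) == 1:
        return vals[0]
    return vals

print({k: flatten(v) for k, v in args.items()})
# {'model': './model.npz',
#  'c': ['configs/model/teacher.yml', 'configs/training/teacher.train.yml'],
#  'seed': '1111',
#  'after': '2e'}

In a Taskcluster run, paths under configs/model and configs/training are then opened and parsed as YAML; outside Taskcluster only the arguments dictionary is kept, which is the "skipping" log line the updated tests assert.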