From 70eff55b82ebf088b55f444361f307441104b4eb Mon Sep 17 00:00:00 2001 From: Valentin Rigal Date: Tue, 3 Sep 2024 20:39:10 +0200 Subject: [PATCH] Publish Marian/OpusTrainer configuration YAMLs and dataset statistics (#720) * Publish Marian, OpusTrainer configs and datasets statistics * Update tests * Fixes * Fix tests in CI context * Nit * Store extra config files as new keys of the main config * Plot datasets in a custom chart * Fixes * Support extra-args for offline publication * Fix tests * Suggestions * TRASHME Test publication from CI * Revert "TRASHME Test publication from CI" This reverts commit 2da4a9a3cdfdaf84074f6614a389aa688b6399a2. * Suggestion * Trigger CI * Fix training and model key detection * TRASHME: Trigger publication from CI * Revert "TRASHME: Trigger publication from CI" This reverts commit ad4a3b7368a632a50b959feefb5bb1e756b24e52. --------- Co-authored-by: Evgeny Pavlov Co-authored-by: Evgeny Pavlov --- tests/test_tracking_cli.py | 65 ++++++++++-- tracking/translations_parser/parser.py | 112 ++++++++++++++++++++- tracking/translations_parser/publishers.py | 70 +++++++++---- tracking/translations_parser/utils.py | 9 +- 4 files changed, 216 insertions(+), 40 deletions(-) diff --git a/tests/test_tracking_cli.py b/tests/test_tracking_cli.py index f5a3c1546..4a3d33f78 100644 --- a/tests/test_tracking_cli.py +++ b/tests/test_tracking_cli.py @@ -20,12 +20,17 @@ def tmp_dir(): return Path(DataDir("test_tracking").path) -@pytest.fixture(autouse=True) +@pytest.fixture def disable_wandb(tmp_dir): """Prevent publication on W&B""" + environ = os.environ.copy() os.environ["WANDB_API_KEY"] = "fake" os.environ["WANDB_MODE"] = "offline" os.environ["WANDB_DIR"] = str(tmp_dir / "wandb") + # Remove task ID to prevent publishing context data (training configuration, dataset) + os.environ.pop("TASK_ID", None) + yield + os.environ.update(environ) @pytest.fixture @@ -53,7 +58,7 @@ def samples_dir(): ), ) @patch("translations_parser.publishers.wandb") -def test_taskcluster(wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir): +def test_taskcluster(wandb_mock, getargs_mock, disable_wandb, caplog, samples_dir, tmp_dir): caplog.set_level(logging.INFO) wandb_dir = tmp_dir / "wandb" wandb_dir.mkdir(parents=True) @@ -63,6 +68,11 @@ def test_taskcluster(wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir): assert [(level, message) for _module, level, message in caplog.record_tuples] == [ (logging.INFO, "Reading logs stream."), (logging.INFO, "Detected Marian version 1.10"), + (logging.INFO, "Reading Marian command line arguments."), + ( + logging.INFO, + "Extra configuration files can only be retrieved in Taskcluster context, skipping.", + ), (logging.INFO, "Successfully parsed 1528 lines"), (logging.INFO, "Found 102 training entries"), (logging.INFO, "Found 34 validation entries"), @@ -74,7 +84,7 @@ def test_taskcluster(wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir): c.kwargs["group"], c.kwargs["name"], c.kwargs["id"], - c.kwargs["config"].get("after"), + c.kwargs["config"].get("marian", {}).get("after"), ) for c in wandb_mock.init.call_args_list ] == [ @@ -92,7 +102,9 @@ def test_taskcluster(wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir): return_value=argparse.Namespace(directory=Path(__file__).parent / "data" / "experiments_1_10"), ) @patch("translations_parser.publishers.wandb") -def test_experiments_marian_1_10(wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir): +def test_experiments_marian_1_10( + wandb_mock, getargs_mock, disable_wandb, caplog, samples_dir, tmp_dir +): caplog.set_level(logging.INFO) wandb_dir = tmp_dir / "wandb" wandb_dir.mkdir(parents=True) @@ -112,6 +124,11 @@ def test_experiments_marian_1_10(wandb_mock, getargs_mock, caplog, samples_dir, ), (logging.INFO, "Reading logs stream."), (logging.INFO, "Detected Marian version 1.10"), + (logging.INFO, "Reading Marian command line arguments."), + ( + logging.INFO, + "Extra configuration files can only be retrieved in Taskcluster context, skipping.", + ), (logging.INFO, "Successfully parsed 1878 lines"), (logging.INFO, "Found 550 training entries"), (logging.INFO, "Found 108 validation entries"), @@ -155,7 +172,7 @@ def test_experiments_marian_1_10(wandb_mock, getargs_mock, caplog, samples_dir, c.kwargs["group"], c.kwargs["name"], c.kwargs["id"], - c.kwargs["config"].get("after"), + c.kwargs["config"].get("marian", {}).get("after"), ) for c in wandb_mock.init.call_args_list ] == [ @@ -211,7 +228,9 @@ def test_experiments_marian_1_10(wandb_mock, getargs_mock, caplog, samples_dir, return_value=argparse.Namespace(directory=Path(__file__).parent / "data" / "experiments_1_12"), ) @patch("translations_parser.publishers.wandb") -def test_experiments_marian_1_12(wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir): +def test_experiments_marian_1_12( + wandb_mock, getargs_mock, disable_wandb, caplog, samples_dir, tmp_dir +): caplog.set_level(logging.INFO) wandb_dir = tmp_dir / "wandb" wandb_dir.mkdir(parents=True) @@ -224,6 +243,12 @@ def test_experiments_marian_1_12(wandb_mock, getargs_mock, caplog, samples_dir, assert set([(level, message) for _module, level, message in caplog.record_tuples]) == set( [ (logging.INFO, "Reading 2 train.log data"), + (logging.INFO, "Detected Marian version 1.12"), + (logging.INFO, "Reading Marian command line arguments."), + ( + logging.INFO, + "Extra configuration files can only be retrieved in Taskcluster context, skipping.", + ), ( logging.INFO, f"Parsing folder {samples_dir}/experiments_1_12/models/fi-en/opusprod/student", @@ -247,7 +272,6 @@ def test_experiments_marian_1_12(wandb_mock, getargs_mock, caplog, samples_dir, (logging.INFO, "Found 4 quantized metrics from speed folder"), (logging.INFO, "Found 8 metrics from task logs"), (logging.INFO, "Creating missing run quantized with associated metrics"), - (logging.INFO, "Detected Marian version 1.12"), ] ) @@ -257,7 +281,7 @@ def test_experiments_marian_1_12(wandb_mock, getargs_mock, caplog, samples_dir, c.kwargs["group"], c.kwargs["name"], c.kwargs["id"], - c.kwargs["config"].get("after"), + c.kwargs["config"].get("marian", {}).get("after"), ) for c in wandb_mock.init.call_args_list ] == [ @@ -317,7 +341,7 @@ def test_experiments_marian_1_12(wandb_mock, getargs_mock, caplog, samples_dir, ) @patch("translations_parser.publishers.wandb") def test_taskcluster_wandb_initialization_failure( - wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir + wandb_mock, getargs_mock, disable_wandb, caplog, samples_dir, tmp_dir ): """ Ensures tracking continues despite W&B initialization failure @@ -328,6 +352,11 @@ def test_taskcluster_wandb_initialization_failure( assert [(level, message) for _module, level, message in caplog.record_tuples] == [ (logging.INFO, "Reading logs stream."), (logging.INFO, "Detected Marian version 1.10"), + (logging.INFO, "Reading Marian command line arguments."), + ( + logging.INFO, + "Extra configuration files can only be retrieved in Taskcluster context, skipping.", + ), ( logging.ERROR, "WandB client could not be initialized: Invalid credentials. No data will be published.", @@ -358,7 +387,9 @@ def test_taskcluster_wandb_initialization_failure( ), ) @patch("translations_parser.publishers.wandb") -def test_taskcluster_wandb_log_failures(wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir): +def test_taskcluster_wandb_log_failures( + wandb_mock, getargs_mock, disable_wandb, caplog, samples_dir, tmp_dir +): """ Ensures tracking continues despite potential W&B data log failures """ @@ -372,6 +403,11 @@ def test_taskcluster_wandb_log_failures(wandb_mock, getargs_mock, caplog, sample assert [(level, message) for _module, level, message in caplog.record_tuples] == [ (logging.INFO, "Reading logs stream."), (logging.INFO, "Detected Marian version 1.10"), + (logging.INFO, "Reading Marian command line arguments."), + ( + logging.INFO, + "Extra configuration files can only be retrieved in Taskcluster context, skipping.", + ), ] + [ (logging.ERROR, "Error publishing training epoch using WandB: Unexpected failure"), (logging.ERROR, "Error publishing training epoch using WandB: Unexpected failure"), @@ -404,7 +440,9 @@ def test_taskcluster_wandb_log_failures(wandb_mock, getargs_mock, caplog, sample ), ) @patch("translations_parser.publishers.wandb") -def test_taskcluster_wandb_disabled(wandb_mock, getargs_mock, caplog, samples_dir, tmp_dir): +def test_taskcluster_wandb_disabled( + wandb_mock, getargs_mock, disable_wandb, caplog, samples_dir, tmp_dir +): """ Ensures tracking continues without Weight & Biases publication """ @@ -417,6 +455,11 @@ def test_taskcluster_wandb_disabled(wandb_mock, getargs_mock, caplog, samples_di ), (logging.INFO, "Reading logs stream."), (logging.INFO, "Detected Marian version 1.10"), + (logging.INFO, "Reading Marian command line arguments."), + ( + logging.INFO, + "Extra configuration files can only be retrieved in Taskcluster context, skipping.", + ), (logging.INFO, "Successfully parsed 1528 lines"), (logging.INFO, "Found 102 training entries"), (logging.INFO, "Found 34 validation entries"), diff --git a/tracking/translations_parser/parser.py b/tracking/translations_parser/parser.py index c908e49dc..dc4ee8aa7 100644 --- a/tracking/translations_parser/parser.py +++ b/tracking/translations_parser/parser.py @@ -1,16 +1,20 @@ import logging +import os import re +import shlex import sys from collections import defaultdict from collections.abc import Iterable, Iterator, Sequence from datetime import datetime from itertools import tee +from pathlib import Path from typing import Callable, DefaultDict, List import yaml from translations_parser.data import Metric, TrainingEpoch, TrainingLog, ValidationEpoch from translations_parser.publishers import Publisher +from translations_parser.utils import get_lines_count logger = logging.getLogger(__name__) @@ -36,6 +40,10 @@ # Expected version of Marian for a clean parsing SUPPORTED_MARIAN_VERSIONS = [(1, 10), (1, 12)] +MARIAN_ARGS_REGEX = re.compile(r"command line:[\n ]+[\w\/-]+\/marian +(.*)") +# Last Marian command line argument (not being part of training extra arguments) +LAST_MARIAN_DECLARED_ARGUMENT = "seed" + class TrainingParser: def __init__( @@ -215,6 +223,97 @@ def _join(seq): yield headers, text + def get_extra_marian_config(self) -> dict: + """ + Read extra configuration files (Marian, OpusTrainer, extra CLI arguments). + Publication outside of a Taskcluster context (offline mode) cannot access + the configuration files, only extra-args will be set in this case. + """ + extra_config = { + "arguments": None, + "model": None, + "training": None, + "datasets": None, + "opustrainer": None, + } + + if ( + self.description is None + or (match := MARIAN_ARGS_REGEX.search(self.description)) is None + ): + logger.error(self.description) + logger.warning( + "Invalid Marian description, skipping Marian and OpusTrainer configuration detection." + ) + return extra_config + + logger.info("Reading Marian command line arguments.") + (arguments_str,) = match.groups() + # Build args from the command line input text + args = defaultdict(list) + key = None + for i in iter(shlex.split(arguments_str)): + if i.startswith("-"): + key = i.strip("-") + continue + args[key].append(i) + + # Store arguments used to run Marian, flattening single values + def flatten(vals): + if not vals: + return "" + elif len(vals) == 1: + return vals[0] + return vals + + extra_config["arguments"] = {k: flatten(v) for k, v in args.items()} + + if os.environ.get("TASK_ID") is None: + logger.info( + "Extra configuration files can only be retrieved in Taskcluster context, skipping." + ) + return extra_config + + # Handle Marian model and training YAML configuration files (called as --config or -c) + for path in args.get("config", args["c"]): + if path.startswith("configs/training"): + key = "training" + elif path.startswith("configs/model"): + key = "model" + else: + continue + try: + with open(path, "r") as f: + extra_config[key] = yaml.safe_load(f.read()) + except Exception as e: + logger.warning(f"Impossible to parse Marian {key} config at {path}: {e}") + + # Handle OpusTrainer configuration + (model_path,) = args.get("model", ("./model.npz",)) + model_dir = Path(model_path).parent + train_conf_path = (model_dir / "config.opustrainer.yml").resolve() + if not train_conf_path.exists(): + logger.warning(f"OpusTrainer configuration file does not exists at {train_conf_path}.") + else: + try: + with open(train_conf_path, "r") as f: + extra_config["opustrainer"] = yaml.safe_load(f.read()) + except Exception as e: + logger.warning(f"Impossible to parse OpusTrainer config at {train_conf_path}: {e}") + else: + logger.info("Reading datasets statistics from OpusTrainer configuration.") + try: + dataset_conf = extra_config.get("opustrainer", {}).get("datasets", {}) + extra_config["datasets"] = { + key: get_lines_count(path) for key, path in dataset_conf.items() + } + except Exception as e: + logger.warning( + f"OpusTrainer configuration could not be read at {train_conf_path}: {e}." + ) + + return extra_config + def parse_marian_context(self, logs_iter: Iterator[tuple[list[tuple[str]], str]]) -> None: """ Looks for Marian context in the first logs lines. @@ -231,6 +330,7 @@ def parse_marian_context(self, logs_iter: Iterator[tuple[list[tuple[str]], str]] version = version.rstrip(";") major, minor = map(int, version.lstrip("v").split(".")[:2]) self.version = f"{major}.{minor}" + logger.info(f"Detected Marian version {self.version}") if (major, minor) not in SUPPORTED_MARIAN_VERSIONS: versions = ", ".join(f"{major}.{minor}" for major, minor in SUPPORTED_MARIAN_VERSIONS) logger.warning( @@ -241,7 +341,8 @@ def parse_marian_context(self, logs_iter: Iterator[tuple[list[tuple[str]], str]] logger.debug("Reading Marian run description.") desc = [] for headers, text in logs_iter: - if ("marian",) not in headers: + # Marian headers stops when dumping the configuration + if ("config",) in headers: break desc.append(text) self.description = " ".join(desc) @@ -257,11 +358,12 @@ def parse_marian_context(self, logs_iter: Iterator[tuple[list[tuple[str]], str]] config_yaml += f"{text}\n" headers, text = next(logs_iter) try: - self.config = yaml.safe_load(config_yaml) + self.config["marian"] = yaml.safe_load(config_yaml) except Exception as e: - raise Exception(f"Invalid config section: {e}") + logger.error(f"Impossible to parse Marian config YAML: {e}") - logger.info(f"Detected Marian version {self.version}") + # Try to read required extra configuration files when running online from Taskcluster + self.config.update(self.get_extra_marian_config()) def parse_data(self, logs_iter: Iterator[tuple[list[tuple[str]], str]]) -> None: """ @@ -310,7 +412,7 @@ def parse(self) -> None: # Once all data has been parsed, call the final publication API for publisher in self.publishers: try: - publisher.publish(self.output) + publisher.publish() # Publish optional metrics if self.metrics: publisher.handle_metrics(self.metrics) diff --git a/tracking/translations_parser/publishers.py b/tracking/translations_parser/publishers.py index da16ee49e..6cc6e1104 100644 --- a/tracking/translations_parser/publishers.py +++ b/tracking/translations_parser/publishers.py @@ -10,7 +10,7 @@ import wandb import yaml -from translations_parser.data import Metric, TrainingEpoch, TrainingLog, ValidationEpoch +from translations_parser.data import Metric, TrainingEpoch, ValidationEpoch from translations_parser.utils import parse_task_label logger = logging.getLogger(__name__) @@ -39,7 +39,7 @@ def handle_validation(self, validation: ValidationEpoch) -> None: def handle_metrics(self, metrics: Sequence[Metric]) -> None: ... - def publish(self, log: TrainingLog) -> None: + def publish(self) -> None: ... def close(self) -> None: @@ -48,9 +48,15 @@ def close(self) -> None: class CSVExport(Publisher): def __init__(self, output_dir: Path) -> None: + from translations_parser.parser import TrainingParser + if not output_dir.is_dir(): raise ValueError("Output must be a valid directory for the CSV export") self.output_dir = output_dir + self.parser: TrainingParser | None = None + + def open(self, parser=None) -> None: + self.parser = parser def write_data( self, output: Path, entries: Sequence[TrainingEpoch | ValidationEpoch], dataclass: type @@ -63,7 +69,9 @@ def write_data( for entry in entries: writer.writerow(vars(entry)) - def publish(self, training_log: TrainingLog) -> None: + def publish(self) -> None: + assert self.parser is not None, "Parser must be set to run CSV publication." + training_log = self.parser.output training_output = self.output_dir / "training.csv" if training_output.exists(): logger.warning(f"Training output file {training_output} exists, skipping.") @@ -109,10 +117,30 @@ def __init__( self.parser: TrainingParser | None = None self.wandb: wandb.sdk.wandb_run.Run | wandb.sdk.lib.disabled.RunDisabled | None = None + def close(self) -> None: + if self.wandb is None: + return + + # Publish artifacts + if self.artifacts: + artifact = wandb.Artifact(name=self.artifacts_name, type=self.artifacts_name) + artifact.add_dir(local_path=str(self.artifacts.resolve())) + self.wandb.log_artifact(artifact) + + if self.parser is not None: + # Store Marian logs as the main log artifact, instead of W&B client runtime. + # This will be overwritten in case an unhandled exception occurs. + for line in self.parser.parsed_logs: + sys.stdout.write(f"{line}\n") + + self.wandb.finish() + def open(self, parser=None, resume: bool = False) -> None: self.parser = parser - config = getattr(parser, "config", {}) + config = getattr(parser, "config", {}).copy() config.update(self.extra_kwargs.pop("config", {})) + # Publish datasets stats directly in the dashboard + datasets = config.pop("datasets", None) # Avoid overriding an existing run on a first training, this should not happen if resume is False and int(os.environ.get("RUN_ID", 0)) > 0: @@ -136,6 +164,22 @@ def open(self, parser=None, resume: bool = False) -> None: except Exception as e: logger.error(f"WandB client could not be initialized: {e}. No data will be published.") + if datasets is not None: + # Log dataset sizes as a custom bar chart + self.wandb.log( + { + "Datasets": wandb.plot.bar( + wandb.Table( + columns=["Name", "Count"], + data=[[key, value] for key, value in datasets.items()], + ), + "Name", + "Count", + title="Datasets", + ) + } + ) + def generic_log(self, data: TrainingEpoch | ValidationEpoch) -> None: if self.wandb is None: return @@ -181,24 +225,6 @@ def handle_metrics(self, metrics: Sequence[Metric]) -> None: } ) - def close(self) -> None: - if self.wandb is None: - return - - # Publish artifacts - if self.artifacts: - artifact = wandb.Artifact(name=self.artifacts_name, type=self.artifacts_name) - artifact.add_dir(local_path=str(self.artifacts.resolve())) - self.wandb.log_artifact(artifact) - - if self.parser is not None: - # Store Marian logs as the main log artifact, instead of W&B client runtime. - # This will be overwritten in case an unhandled exception occurs. - for line in self.parser.parsed_logs: - sys.stdout.write(f"{line}\n") - - self.wandb.finish() - @classmethod def publish_group_logs( cls, diff --git a/tracking/translations_parser/utils.py b/tracking/translations_parser/utils.py index 16bb2902c..80019bf27 100644 --- a/tracking/translations_parser/utils.py +++ b/tracking/translations_parser/utils.py @@ -229,8 +229,8 @@ def publish_group_logs_from_tasks( for metric_task_id, metrics_task in metrics_tasks.items(): filename = metrics_task["task"]["tags"]["label"] if re_match := MULTIPLE_TRAIN_SUFFIX.search(filename): - (suffix,) = re_match.groups() - filename = MULTIPLE_TRAIN_SUFFIX.sub(suffix, filename) + (train_suffix,) = re_match.groups() + filename = MULTIPLE_TRAIN_SUFFIX.sub(train_suffix, filename) metric_artifact = next( ( @@ -275,3 +275,8 @@ def suffix_from_group(task_group_id: str) -> str: len(task_group_id) >= 5 ), f"Taskcluster group ID should contain more than 5 characters: {task_group_id}" return f"_{task_group_id[:5]}" + + +def get_lines_count(file_path: str) -> int: + with open(file_path, "r") as f: + return sum(1 for _ in f)