From 714091207d9e593511741e85683a95e967792b69 Mon Sep 17 00:00:00 2001
From: Jason Senthil
Date: Tue, 21 Nov 2023 17:55:01 -0800
Subject: [PATCH] add async flag to CSVWriter

Summary: Adds `async_write` flag to `CSVLogger` to make the `log()` call non-blocking

Reviewed By: galrotem

Differential Revision: D51505159
---
 tests/utils/loggers/test_csv.py | 16 +++++++++++++
 torchtnt/utils/loggers/csv.py   | 42 +++++++++++++++++++++++++--------
 2 files changed, 48 insertions(+), 10 deletions(-)

diff --git a/tests/utils/loggers/test_csv.py b/tests/utils/loggers/test_csv.py
index b416a4eb3c..a13196c471 100644
--- a/tests/utils/loggers/test_csv.py
+++ b/tests/utils/loggers/test_csv.py
@@ -29,3 +29,19 @@ def test_csv_log(self) -> None:
             # pyre-fixme[16]: `_DictReadMapping` has no attribute `__getitem__`.
             self.assertEqual(float(output[0][log_name]), log_value)
             self.assertEqual(int(output[0]["step"]), log_step)
+
+    def test_csv_log_async(self) -> None:
+        with TemporaryDirectory() as tmpdir:
+            csv_path = Path(tmpdir, "test.csv").as_posix()
+            logger = CSVLogger(csv_path, steps_before_flushing=1, async_write=True)
+            log_name = "asdf"
+            log_value = 123.0
+            log_step = 10
+            logger.log(log_name, log_value, log_step)
+            logger.close()
+
+            with open(csv_path) as f:
+                output = list(csv.DictReader(f))
+                # pyre-fixme[16]: `_DictReadMapping` has no attribute `__getitem__`.
+                self.assertEqual(float(output[0][log_name]), log_value)
+                self.assertEqual(int(output[0]["step"]), log_step)
diff --git a/torchtnt/utils/loggers/csv.py b/torchtnt/utils/loggers/csv.py
index 91b305ce9b..93a3fdaf35 100644
--- a/torchtnt/utils/loggers/csv.py
+++ b/torchtnt/utils/loggers/csv.py
@@ -7,6 +7,8 @@
 
 import csv
 import logging
+from threading import Thread
+from typing import Dict, List, Optional
 
 from fsspec import open as fs_open
 from torchtnt.utils.loggers.file import FileLogger
@@ -23,6 +25,7 @@ class CSVLogger(FileLogger, MetricLogger):
         path (str): path to write logs to
         steps_before_flushing: (int, optional): Number of steps to buffer in logger before flushing
         log_all_ranks: (bool, optional): Log all ranks if true, else log only on rank 0.
+        async_write: (bool, optional): Whether to write asynchronously or not. Defaults to False.
""" def __init__( @@ -30,21 +33,40 @@ def __init__( path: str, steps_before_flushing: int = 100, log_all_ranks: bool = False, + async_write: bool = False, ) -> None: super().__init__(path, steps_before_flushing, log_all_ranks) - def flush(self) -> None: - data = self._log_buffer - if not data: - logger.debug("No logs to write.") - return + self._async_write = async_write + self._thread: Optional[Thread] = None + def flush(self) -> None: if self._rank == 0 or self._log_all_ranks: - with fs_open(self.path, "w") as f: - data_list = list(data.values()) - w = csv.DictWriter(f, data_list[0].keys()) - w.writeheader() - w.writerows(data_list) + buffer = self._log_buffer + if not buffer: + logger.debug("No logs to write.") + return + + if self._thread: + # ensure previous thread is completed before next write + self._thread.join() + + data_list = list(buffer.values()) + if not self._async_write: + _write_csv(self.path, data_list) + return + + self._thread = Thread(target=_write_csv, args=(self.path, data_list)) + self._thread.start() def close(self) -> None: + # toggle off async writing for final flush + self._async_write = False self.flush() + + +def _write_csv(path: str, data_list: List[Dict[str, float]]) -> None: + with fs_open(path, "w") as f: + w = csv.DictWriter(f, data_list[0].keys()) + w.writeheader() + w.writerows(data_list)