Skip to content
16 changes: 14 additions & 2 deletions ndsl/monitor/netcdf_monitor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
from warnings import warn

import fsspec
import numpy as np

from ndsl.comm.communicator import Communicator
from ndsl.dsl.typing import Float, get_precision
from ndsl.filesystem import get_fs
from ndsl.logging import ndsl_log
from ndsl.monitor.convert import to_numpy
Expand Down Expand Up @@ -114,6 +116,7 @@ def __init__(
path: str,
communicator: Communicator,
time_chunk_size: int = 1,
precision=Float,
):
"""Create a NetCDFMonitor.

Expand All @@ -130,6 +133,11 @@ def __init__(
self._time_chunk_size = time_chunk_size
self.__writer: Optional[_ChunkedNetCDFWriter] = None
self._expected_vars: Optional[Set[str]] = None
self._transfer_type = precision
if self._transfer_type == np.float32 and get_precision() > 32:
warn(
f"NetCDF save: requested 32-bit float but precision of NDSL is {get_precision()}, cast will occur with possible loss of precision"
)

@property
def _writer(self):
Expand Down Expand Up @@ -164,12 +172,16 @@ def store(self, state: dict) -> None:
set(state.keys()), self._expected_vars
)
)
state = self._communicator.tile.gather_state(state, transfer_type=np.float32)
state = self._communicator.tile.gather_state(
state, transfer_type=self._transfer_type
)
if state is not None: # we are on root rank
self._writer.append(state)

def store_constant(self, state: Dict[str, Quantity]) -> None:
state = self._communicator.gather_state(state, transfer_type=np.float32)
state = self._communicator.gather_state(
state, transfer_type=self._transfer_type
)
if state is not None: # we are on root rank
constants_filename = str(
Path(self._path) / NetCDFMonitor._CONSTANT_FILENAME
Expand Down