"""Run a short NeuralGCM forecast initialised from ERA5 reanalysis.

Pipeline:
1. Load a pretrained NeuralGCM checkpoint from a public Google Cloud Storage
   bucket (anonymous access).
2. Open the ARCO ERA5 Zarr store and select the model's input and forcing
   variables. Shift the forcing variables in time. Keep every 12th time sample
   between the demo start and end dates, then load that slice into memory.
3. Conservatively regrid the data from the native ERA5 grid onto the model's
   horizontal grid, then fill any remaining NaNs with nearest-neighbour values.
4. Encode the first time step as the initial model state and unroll the model,
   using the first time step's forcings throughout (persistence).
5. Convert the predictions back to an xarray Dataset (``predictions_ds``).
"""

import pickle

import gcsfs
import jax
import numpy as np
import xarray
from dinosaur import horizontal_interpolation
from dinosaur import spherical_harmonic
from dinosaur import xarray_utils

import neuralgcm

# Anonymous (unauthenticated) access; both buckets used below are public.
gcs = gcsfs.GCSFileSystem(token='anon')

model_name = 'neural_gcm_dynamic_forcing_deterministic_1_4_deg.pkl'
# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# Only load checkpoints from a source you trust.
with gcs.open(f'gs://gresearch/neuralgcm/04_30_2024/{model_name}', 'rb') as f:
    ckpt = pickle.load(f)

model = neuralgcm.PressureLevelModel.from_checkpoint(ckpt)

# ERA5 store; per its path: 37 levels, 1-hourly, 0.25-degree resolution.
era5_path = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'
full_era5 = xarray.open_zarr(gcs.get_mapper(era5_path), chunks=None)

demo_start_time = '2023-05-15'
demo_end_time = '2023-05-18'
data_inner_steps = 12  # process every 12th hour

sliced_era5 = (
    full_era5[model.input_variables + model.forcing_variables]
    # Shift only the forcing variables by 24 hours. Presumably this is so the
    # model sees forcings as they would be available at forecast time.
    # TODO confirm against the dinosaur `selective_temporal_shift` docs.
    .pipe(
        xarray_utils.selective_temporal_shift,
        variables=model.forcing_variables,
        time_shift='24 hours',
    )
    # The slice step is positional, i.e. every `data_inner_steps`-th sample.
    .sel(time=slice(demo_start_time, demo_end_time, data_inner_steps))
    .compute()  # load the (small) selected slice into memory
)

# Describe the native ERA5 grid so it can be regridded to the model grid.
era5_grid = spherical_harmonic.Grid(
    latitude_nodes=full_era5.sizes['latitude'],
    longitude_nodes=full_era5.sizes['longitude'],
    latitude_spacing=xarray_utils.infer_latitude_spacing(full_era5.latitude),
    longitude_offset=xarray_utils.infer_longitude_offset(full_era5.longitude),
)
regridder = horizontal_interpolation.ConservativeRegridder(
    era5_grid, model.data_coords.horizontal, skipna=True
)
eval_era5 = xarray_utils.regrid(sliced_era5, regridder)
# Regridding with skipna=True can still leave NaN cells; fill them so the
# model never receives NaN inputs.
eval_era5 = xarray_utils.fill_nan_with_nearest(eval_era5)

inner_steps = 12  # save model outputs once every 12 hours
outer_steps = 3 * 24 // inner_steps  # total of 3 days
timedelta = np.timedelta64(1, 'h') * inner_steps
times = np.arange(outer_steps) * inner_steps  # time axis in hours

# initialize model state
inputs = model.inputs_from_xarray(eval_era5.isel(time=0))
input_forcings = model.forcings_from_xarray(eval_era5.isel(time=0))
rng_key = jax.random.key(42)  # optional for deterministic models
initial_state = model.encode(inputs, input_forcings, rng_key)

# use persistence for forcing variables (SST and sea ice cover)
all_forcings = model.forcings_from_xarray(eval_era5.head(time=1))

# make forecast
final_state, predictions = model.unroll(
    initial_state,
    all_forcings,
    steps=outer_steps,
    timedelta=timedelta,
    start_with_input=True,
)
predictions_ds = model.data_to_xarray(predictions, times=times)