1
- import os
2
1
import logging
3
2
from collections import OrderedDict
4
3
from sys import stdout
5
- from sys import path
4
+ from typing import Optional
6
5
7
6
import numpy as np
8
- import h5py
7
+ import h5py #type: ignore
9
8
from mpi4py import MPI
10
- from scipy .interpolate import RegularGridInterpolator
9
+ from scipy .interpolate import RegularGridInterpolator #type: ignore
11
10
import matplotlib
12
11
matplotlib .use ('Agg' )
13
- import matplotlib .pyplot as plt
14
12
matplotlib .rcParams .update ({'font.size' : 9 })
15
13
16
- from dedalus .tools .parallel import Sync
17
-
18
14
from plotpal .file_reader import SingleTypeReader , match_basis
19
15
from plotpal .plot_grid import RegularPlotGrid
20
16
@@ -30,16 +26,36 @@ class PdfPlotter(SingleTypeReader):
30
26
that basis is evenly interpolated to avoid skewing of the distribution by uneven grid sampling.
31
27
"""
32
28
33
- def __init__ (self , * args , ** kwargs ):
29
+ def __init__ (
30
+ self ,
31
+ run_dir : str ,
32
+ sub_dir : str ,
33
+ out_name : str ,
34
+ distribution : str = 'even-write' ,
35
+ num_files : Optional [int ] = None ,
36
+ roll_writes : Optional [int ] = None ,
37
+ start_file : int = 1 ,
38
+ global_comm : MPI .Intracomm = MPI .COMM_WORLD ,
39
+ chunk_size : int = 1000
40
+ ):
34
41
"""
35
42
Initializes the PDF plotter.
36
43
"""
37
- super (PdfPlotter , self ).__init__ (* args , distribution = 'even-write' , ** kwargs )
38
- self .pdfs = OrderedDict ()
39
- self .pdf_stats = OrderedDict ()
40
-
41
-
42
- def _calculate_pdf_statistics (self ):
44
+ super (PdfPlotter , self ).__init__ (
45
+ run_dir = run_dir ,
46
+ sub_dir = sub_dir ,
47
+ out_name = out_name ,
48
+ distribution = distribution ,
49
+ num_files = num_files ,
50
+ roll_writes = roll_writes ,
51
+ start_file = start_file ,
52
+ global_comm = global_comm ,
53
+ chunk_size = chunk_size
54
+ )
55
+ self .pdfs : dict [str , tuple [np .ndarray , np .ndarray , np .ndarray ]] = OrderedDict ()
56
+ self .pdf_stats : dict [str , tuple [float , float , float , float ]] = OrderedDict ()
57
+
58
+ def _calculate_pdf_statistics (self ) -> None :
43
59
""" Calculate statistics of the PDFs stored in self.pdfs. Store results in self.pdf_stats. """
44
60
for k , data in self .pdfs .items ():
45
61
pdf , x_vals , dx = data
@@ -51,17 +67,19 @@ def _calculate_pdf_statistics(self):
51
67
self .pdf_stats [k ] = (mean , stdev , skew , kurt )
52
68
53
69
54
- def _get_interpolated_slices (self , dsets , ni , uneven_basis = None ):
70
+ def _get_interpolated_slices (
71
+ self ,
72
+ dsets : dict [str , h5py .Dataset ],
73
+ ni : int ,
74
+ uneven_basis : Optional [str ] = None
75
+ ) -> dict [str , np .ndarray ]:
55
76
"""
56
77
For 2D data on an uneven grid, interpolates that data on to an evenly spaced grid.
57
78
58
79
# Arguments
59
- dsets (dict) :
60
- A dictionary of links to dedalus output tasks in hdf5 files.
61
- ni (int) :
62
- The index of the slice to be interpolate.
63
- uneven_basis (string, optional) :
64
- The basis on which the grid has uneven spacing.
80
+ dsets : A dictionary of links to dedalus output tasks in hdf5 files.
81
+ ni : The index of the slice to be interpolate.
82
+ uneven_basis : The basis on which the grid has uneven spacing.
65
83
"""
66
84
#Read data
67
85
bases = self .current_bases
@@ -87,17 +105,19 @@ def _get_interpolated_slices(self, dsets, ni, uneven_basis=None):
87
105
88
106
return file_data
89
107
90
- def _get_interpolated_volumes (self , dsets , ni , uneven_basis = None ):
108
+ def _get_interpolated_volumes (
109
+ self ,
110
+ dsets : dict [str , h5py .Dataset ],
111
+ ni : int ,
112
+ uneven_basis : Optional [str ] = None
113
+ ) -> dict [str , np .ndarray ]:
91
114
"""
92
115
For 3D data on an uneven grid, interpolates that data on to an evenly spaced grid.
93
116
94
117
# Arguments
95
- dsets (dict) :
96
- A dictionary of links to dedalus output tasks in hdf5 files.
97
- ni (int) :
98
- The index of the field to be interpolate.
99
- uneven_basis (string, optional) :
100
- The basis on which the grid has uneven spacing.
118
+ dsets : A dictionary of links to dedalus output tasks in hdf5 files.
119
+ ni : The index of the field to be interpolate.
120
+ uneven_basis : The basis on which the grid has uneven spacing.
101
121
"""
102
122
#Read data
103
123
bases = self .current_bases
@@ -142,29 +162,28 @@ def _get_interpolated_volumes(self, dsets, ni, uneven_basis=None):
142
162
print ('interpolating {} ({}/{})...' .format (k , i + 1 , file_data [k ].shape [0 ]))
143
163
stdout .flush ()
144
164
if uneven_index is None :
145
- file_data [k ][i ,:] = tdsets [k ][ni ][i ,:]
165
+ file_data [k ][i ,:] = dsets [k ][ni ][i ,:]
146
166
elif uneven_index == 2 :
147
167
for j in range (file_data [k ].shape [- 2 ]): # loop over y
148
- interp = RegularGridInterpolator ((x .flatten (), z .flatten ()), tsk [k ][i ,:,j ,:], method = 'linear' )
168
+ interp = RegularGridInterpolator ((x .flatten (), z .flatten ()), dsets [k ][i ,:,j ,:], method = 'linear' )
149
169
file_data [k ][i ,:,j ,:] = interp ((exx , ezz ))
150
170
else :
151
171
for j in range (file_data [k ].shape [- 1 ]): # loop over z
152
- interp = RegularGridInterpolator ((x .flatten (), y .flatten ()), tsk [k ][i ,:,:,j ], method = 'linear' )
172
+ interp = RegularGridInterpolator ((x .flatten (), y .flatten ()), dsets [k ][i ,:,:,j ], method = 'linear' )
153
173
file_data [k ][i ,:,:,j ] = interp ((exx , eyy ))
154
174
155
175
return file_data
156
176
157
- def _get_bounds (self , pdf_list ) :
177
+ def _get_bounds (self , pdf_list : list [ str ]) -> dict [ str , np . ndarray ] :
158
178
"""
159
179
Finds the global minimum and maximum value of fields for determing PDF range.
160
180
161
181
Arguments
162
182
---------
163
- pdf_list : list
164
- A list of fields for which to calculate the global minimum and maximum.
183
+ pdf_list : A list of fields for which to calculate the global minimum and maximum.
165
184
"""
166
185
with self .my_sync :
167
- if self .idle : return
186
+ if self .idle : return {}
168
187
169
188
bounds = OrderedDict ()
170
189
for field in pdf_list :
@@ -195,7 +214,14 @@ def _get_bounds(self, pdf_list):
195
214
return bounds
196
215
197
216
198
- def calculate_pdfs (self , pdf_list , bins = 100 , threeD = False , bases = ['x' , 'z' ], ** kwargs ):
217
+ def calculate_pdfs (
218
+ self ,
219
+ pdf_list : list [str ],
220
+ bins : int = 100 ,
221
+ threeD : bool = False ,
222
+ bases : list [str ]= ['x' , 'z' ],
223
+ uneven_basis : Optional [str ] = None ,
224
+ ) -> None :
199
225
"""
200
226
Calculate probability distribution functions of the specified tasks.
201
227
@@ -206,10 +232,9 @@ def calculate_pdfs(self, pdf_list, bins=100, threeD=False, bases=['x', 'z'], **k
206
232
The number of bins the PDF (histogram) should have
207
233
threeD (bool, optional) :
208
234
If True, find PDF of a 3D volume
209
- bases (list, optional) :
210
- A list of strings of the bases over which the simulation information spans.
235
+ bases : A list of strings of the bases over which the simulation information spans.
211
236
Should have 2 elements if threeD is False, 3 elements if threeD is True.
212
- **kwargs : additional keyword arguments for the self._get_interpolated_slices() function .
237
+ uneven_basis : The basis on which the grid has uneven spacing, if any .
213
238
"""
214
239
self .current_bases = bases
215
240
bounds = self ._get_bounds (pdf_list )
@@ -228,13 +253,13 @@ def calculate_pdfs(self, pdf_list, bins=100, threeD=False, bases=['x', 'z'], **k
228
253
229
254
# Interpolate data onto a regular grid
230
255
if threeD :
231
- file_data = self ._get_interpolated_volumes (dsets , ni , ** kwargs )
256
+ file_data = self ._get_interpolated_volumes (dsets , ni , uneven_basis = uneven_basis )
232
257
else :
233
- file_data = self ._get_interpolated_slices (dsets , ni , ** kwargs )
258
+ file_data = self ._get_interpolated_slices (dsets , ni , uneven_basis = uneven_basis )
234
259
235
260
# Create histograms of data
236
261
for field in pdf_list :
237
- hist , bin_vals = np .histogram (file_data [field ], bins = bins , range = bounds [field ])
262
+ hist , bin_vals = np .histogram (file_data [field ], bins = bins , range = tuple ( bounds [field ]) )
238
263
histograms [field ] += hist
239
264
bin_edges [field ] = bin_vals
240
265
@@ -254,19 +279,35 @@ def calculate_pdfs(self, pdf_list, bins=100, threeD=False, bases=['x', 'z'], **k
254
279
self ._calculate_pdf_statistics ()
255
280
256
281
257
- def plot_pdfs (self , dpi = 150 , ** kwargs ):
282
+ def plot_pdfs (
283
+ self ,
284
+ dpi : int = 150 ,
285
+ col_inch : float = 3 ,
286
+ row_inch : float = 3 ,
287
+ ) -> None :
258
288
"""
259
289
Plot the probability distribution functions and save them to file.
260
290
261
291
# Arguments
262
- dpi (int, optional) :
263
- Pixel density of output image .
264
- **kwargs : additional keyword arguments for RegularPlotGrid()
292
+ dpi : Pixel density of output image.
293
+ col_inch : Width of each column in inches .
294
+ row_inch : Height of each row in inches.
265
295
"""
266
296
with self .my_sync :
267
297
if self .comm .rank != 0 : return
268
298
269
- grid = RegularPlotGrid (num_rows = 1 ,num_cols = 1 , ** kwargs )
299
+ grid = RegularPlotGrid (
300
+ num_rows = 1 ,
301
+ num_cols = 1 ,
302
+ cbar = False ,
303
+ polar = False ,
304
+ mollweide = False ,
305
+ orthographic = False ,
306
+ threeD = False ,
307
+ col_inch = col_inch ,
308
+ row_inch = row_inch ,
309
+ pad_factor = 10
310
+ )
270
311
ax = grid .axes ['ax_0-0' ]
271
312
272
313
for k , data in self .pdfs .items ():
@@ -290,7 +331,7 @@ def plot_pdfs(self, dpi=150, **kwargs):
290
331
291
332
self ._save_pdfs ()
292
333
293
- def _save_pdfs (self ):
334
+ def _save_pdfs (self ) -> None :
294
335
"""
295
336
Save PDFs to file. For each PDF, e.g., 'entropy' and 'w', the file will have a dataset with:
296
337
xs - the x-values of the PDF
@@ -303,8 +344,8 @@ def _save_pdfs(self):
303
344
pdf , xs , dx = data
304
345
this_group = f .create_group (k )
305
346
for d , n in ((pdf , 'pdf' ), (xs , 'xs' )):
306
- dset = this_group .create_dataset (name = n , shape = d .shape , dtype = np .float64 )
347
+ this_group .create_dataset (name = n , shape = d .shape , dtype = np .float64 )
307
348
f ['{:s}/{:s}' .format (k , n )][:] = d
308
- dset = this_group .create_dataset (name = 'dx' , shape = (1 ,), dtype = np .float64 )
349
+ this_group .create_dataset (name = 'dx' , shape = (1 ,), dtype = np .float64 )
309
350
f ['{:s}/dx' .format (k )][0 ] = dx
310
351
0 commit comments