Skip to content

Commit 81a8f92

Browse files
committed
Type checking seems to work in 3.9; commiting auto-github-workflows
1 parent 7b76131 commit 81a8f92

File tree

12 files changed

+367
-145
lines changed

12 files changed

+367
-145
lines changed

.github/workflows/mypy.yml

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
name: Mypy Type Checking
2+
3+
on: [push, pull_request]
4+
5+
jobs:
6+
mypy:
7+
runs-on: ubuntu-latest
8+
strategy:
9+
matrix:
10+
python-version: ["3.9", "3.10", "3.11", "3.12"]
11+
12+
steps:
13+
- uses: actions/checkout@v2
14+
- name: Set up Python ${{ matrix.python-version }}
15+
uses: actions/setup-python@v2
16+
with:
17+
python-version: ${{ matrix.python-version }}
18+
- name: Install dependencies
19+
run: |
20+
python -m pip install --upgrade pip
21+
pip install mypy
22+
- name: Run mypy
23+
run: |
24+
mypy --config-file mypy-${{ matrix.python-version }}.ini

examples/d3/ivp_ball_internally_heated_convection/plot_equatorial_slices.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
radius = float(args['--radius'])
3434

3535
# Create Plotter object, tell it which fields to plot
36-
plotter = SlicePlotter(root_dir, file_dir=data_dir, out_name=out_name, start_file=start_file, n_files=n_files)
36+
plotter = SlicePlotter(root_dir, sub_dir=data_dir, out_name=out_name, start_file=start_file, num_files=n_files)
3737
plotter_kwargs = { 'col_inch' : int(args['--col_inch']), 'row_inch' : int(args['--row_inch']), 'pad_factor' : 10 }
3838

3939
# remove_x_mean option removes the (numpy horizontal mean) over phi

mypy-3.10.ini

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[mypy]
2+
python_version = 3.10
3+
disallow_untyped_calls = True
4+
disallow_untyped_defs = True
5+
disallow_incomplete_defs = True
6+
check_untyped_defs = True
7+
ignore_missing_imports = True
8+
files = ./plotpal/,./*.py
9+
exclude = ./examples
10+

mypy-3.11.ini

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[mypy]
2+
python_version = 3.11
3+
disallow_untyped_calls = True
4+
disallow_untyped_defs = True
5+
disallow_incomplete_defs = True
6+
check_untyped_defs = True
7+
ignore_missing_imports = True
8+
files = ./plotpal/,./*.py
9+
exclude = ./examples
10+

mypy-3.12.ini

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[mypy]
2+
python_version = 3.12
3+
disallow_untyped_calls = True
4+
disallow_untyped_defs = True
5+
disallow_incomplete_defs = True
6+
check_untyped_defs = True
7+
ignore_missing_imports = True
8+
files = ./plotpal/,./*.py
9+
exclude = ./examples
10+

mypy-3.9.ini

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[mypy]
2+
python_version = 3.9
3+
disallow_untyped_calls = True
4+
disallow_untyped_defs = True
5+
disallow_incomplete_defs = True
6+
check_untyped_defs = True
7+
ignore_missing_imports = True
8+
files = ./plotpal/,./*.py
9+
exclude = ./examples
10+

plotpal/file_reader.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __init__(
8989
self.idle: dict[str, bool] = OrderedDict() # Whether or not this processor is idle for each file type
9090
self._distribute_writes(distribution, chunk_size=chunk_size)
9191

92-
def _distribute_writes(self, distribution: str, chunk_size=100) -> None:
92+
def _distribute_writes(self, distribution: str, chunk_size: int = 100) -> None:
9393
"""
9494
Distribute writes (or files) across MPI processes according to the specified rule.
9595
@@ -208,7 +208,7 @@ def __init__(
208208
chunk_size=chunk_size
209209
)
210210

211-
def _distribute_writes(self, distribution: str, chunk_size=100) -> None:
211+
def _distribute_writes(self, distribution: str, chunk_size: int = 100) -> None:
212212
super()._distribute_writes(distribution=distribution, chunk_size=chunk_size)
213213
self.roll_starts, self.roll_counts = OrderedDict(), OrderedDict()
214214

@@ -382,7 +382,7 @@ def writes_remain(self) -> bool:
382382
def get_dsets(
383383
self,
384384
tasks: list[str],
385-
verbose=True
385+
verbose: bool = True
386386
) -> tuple[dict[str, Union[h5py.Dataset, RolledDset]], int]:
387387
""" Given a list of task strings, returns a dictionary of the associated datasets and the dset index of the current write. """
388388
if not self.idle:

plotpal/pdfs.py

+91-50
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,16 @@
1-
import os
21
import logging
32
from collections import OrderedDict
43
from sys import stdout
5-
from sys import path
4+
from typing import Optional
65

76
import numpy as np
8-
import h5py
7+
import h5py #type: ignore
98
from mpi4py import MPI
10-
from scipy.interpolate import RegularGridInterpolator
9+
from scipy.interpolate import RegularGridInterpolator #type: ignore
1110
import matplotlib
1211
matplotlib.use('Agg')
13-
import matplotlib.pyplot as plt
1412
matplotlib.rcParams.update({'font.size': 9})
1513

16-
from dedalus.tools.parallel import Sync
17-
1814
from plotpal.file_reader import SingleTypeReader, match_basis
1915
from plotpal.plot_grid import RegularPlotGrid
2016

@@ -30,16 +26,36 @@ class PdfPlotter(SingleTypeReader):
3026
that basis is evenly interpolated to avoid skewing of the distribution by uneven grid sampling.
3127
"""
3228

33-
def __init__(self, *args, **kwargs):
29+
def __init__(
30+
self,
31+
run_dir: str,
32+
sub_dir: str,
33+
out_name: str,
34+
distribution: str = 'even-write',
35+
num_files: Optional[int] = None,
36+
roll_writes: Optional[int] = None,
37+
start_file: int = 1,
38+
global_comm: MPI.Intracomm = MPI.COMM_WORLD,
39+
chunk_size: int = 1000
40+
):
3441
"""
3542
Initializes the PDF plotter.
3643
"""
37-
super(PdfPlotter, self).__init__(*args, distribution='even-write', **kwargs)
38-
self.pdfs = OrderedDict()
39-
self.pdf_stats = OrderedDict()
40-
41-
42-
def _calculate_pdf_statistics(self):
44+
super(PdfPlotter, self).__init__(
45+
run_dir=run_dir,
46+
sub_dir=sub_dir,
47+
out_name=out_name,
48+
distribution=distribution,
49+
num_files=num_files,
50+
roll_writes=roll_writes,
51+
start_file=start_file,
52+
global_comm=global_comm,
53+
chunk_size=chunk_size
54+
)
55+
self.pdfs: dict[str, tuple[np.ndarray, np.ndarray, np.ndarray]] = OrderedDict()
56+
self.pdf_stats: dict[str, tuple[float, float, float, float]] = OrderedDict()
57+
58+
def _calculate_pdf_statistics(self) -> None:
4359
""" Calculate statistics of the PDFs stored in self.pdfs. Store results in self.pdf_stats. """
4460
for k, data in self.pdfs.items():
4561
pdf, x_vals, dx = data
@@ -51,17 +67,19 @@ def _calculate_pdf_statistics(self):
5167
self.pdf_stats[k] = (mean, stdev, skew, kurt)
5268

5369

54-
def _get_interpolated_slices(self, dsets, ni, uneven_basis=None):
70+
def _get_interpolated_slices(
71+
self,
72+
dsets: dict[str, h5py.Dataset],
73+
ni: int,
74+
uneven_basis: Optional[str] = None
75+
) -> dict[str, np.ndarray]:
5576
"""
5677
For 2D data on an uneven grid, interpolates that data on to an evenly spaced grid.
5778
5879
# Arguments
59-
dsets (dict) :
60-
A dictionary of links to dedalus output tasks in hdf5 files.
61-
ni (int) :
62-
The index of the slice to be interpolate.
63-
uneven_basis (string, optional) :
64-
The basis on which the grid has uneven spacing.
80+
dsets : A dictionary of links to dedalus output tasks in hdf5 files.
81+
ni : The index of the slice to be interpolate.
82+
uneven_basis : The basis on which the grid has uneven spacing.
6583
"""
6684
#Read data
6785
bases = self.current_bases
@@ -87,17 +105,19 @@ def _get_interpolated_slices(self, dsets, ni, uneven_basis=None):
87105

88106
return file_data
89107

90-
def _get_interpolated_volumes(self, dsets, ni, uneven_basis=None):
108+
def _get_interpolated_volumes(
109+
self,
110+
dsets: dict[str, h5py.Dataset],
111+
ni: int,
112+
uneven_basis: Optional[str] = None
113+
) -> dict[str, np.ndarray]:
91114
"""
92115
For 3D data on an uneven grid, interpolates that data on to an evenly spaced grid.
93116
94117
# Arguments
95-
dsets (dict) :
96-
A dictionary of links to dedalus output tasks in hdf5 files.
97-
ni (int) :
98-
The index of the field to be interpolate.
99-
uneven_basis (string, optional) :
100-
The basis on which the grid has uneven spacing.
118+
dsets : A dictionary of links to dedalus output tasks in hdf5 files.
119+
ni : The index of the field to be interpolate.
120+
uneven_basis : The basis on which the grid has uneven spacing.
101121
"""
102122
#Read data
103123
bases = self.current_bases
@@ -142,29 +162,28 @@ def _get_interpolated_volumes(self, dsets, ni, uneven_basis=None):
142162
print('interpolating {} ({}/{})...'.format(k, i+1, file_data[k].shape[0]))
143163
stdout.flush()
144164
if uneven_index is None:
145-
file_data[k][i,:] = tdsets[k][ni][i,:]
165+
file_data[k][i,:] = dsets[k][ni][i,:]
146166
elif uneven_index == 2:
147167
for j in range(file_data[k].shape[-2]): # loop over y
148-
interp = RegularGridInterpolator((x.flatten(), z.flatten()), tsk[k][i,:,j,:], method='linear')
168+
interp = RegularGridInterpolator((x.flatten(), z.flatten()), dsets[k][i,:,j,:], method='linear')
149169
file_data[k][i,:,j,:] = interp((exx, ezz))
150170
else:
151171
for j in range(file_data[k].shape[-1]): # loop over z
152-
interp = RegularGridInterpolator((x.flatten(), y.flatten()), tsk[k][i,:,:,j], method='linear')
172+
interp = RegularGridInterpolator((x.flatten(), y.flatten()), dsets[k][i,:,:,j], method='linear')
153173
file_data[k][i,:,:,j] = interp((exx, eyy))
154174

155175
return file_data
156176

157-
def _get_bounds(self, pdf_list):
177+
def _get_bounds(self, pdf_list: list[str]) -> dict[str, np.ndarray]:
158178
"""
159179
Finds the global minimum and maximum value of fields for determing PDF range.
160180
161181
Arguments
162182
---------
163-
pdf_list : list
164-
A list of fields for which to calculate the global minimum and maximum.
183+
pdf_list : A list of fields for which to calculate the global minimum and maximum.
165184
"""
166185
with self.my_sync:
167-
if self.idle : return
186+
if self.idle : return {}
168187

169188
bounds = OrderedDict()
170189
for field in pdf_list:
@@ -195,7 +214,14 @@ def _get_bounds(self, pdf_list):
195214
return bounds
196215

197216

198-
def calculate_pdfs(self, pdf_list, bins=100, threeD=False, bases=['x', 'z'], **kwargs):
217+
def calculate_pdfs(
218+
self,
219+
pdf_list: list[str],
220+
bins: int = 100,
221+
threeD: bool = False,
222+
bases: list[str]=['x', 'z'],
223+
uneven_basis: Optional[str] = None,
224+
) -> None:
199225
"""
200226
Calculate probability distribution functions of the specified tasks.
201227
@@ -206,10 +232,9 @@ def calculate_pdfs(self, pdf_list, bins=100, threeD=False, bases=['x', 'z'], **k
206232
The number of bins the PDF (histogram) should have
207233
threeD (bool, optional) :
208234
If True, find PDF of a 3D volume
209-
bases (list, optional) :
210-
A list of strings of the bases over which the simulation information spans.
235+
bases : A list of strings of the bases over which the simulation information spans.
211236
Should have 2 elements if threeD is False, 3 elements if threeD is True.
212-
**kwargs : additional keyword arguments for the self._get_interpolated_slices() function.
237+
uneven_basis : The basis on which the grid has uneven spacing, if any.
213238
"""
214239
self.current_bases = bases
215240
bounds = self._get_bounds(pdf_list)
@@ -228,13 +253,13 @@ def calculate_pdfs(self, pdf_list, bins=100, threeD=False, bases=['x', 'z'], **k
228253

229254
# Interpolate data onto a regular grid
230255
if threeD:
231-
file_data = self._get_interpolated_volumes(dsets, ni, **kwargs)
256+
file_data = self._get_interpolated_volumes(dsets, ni, uneven_basis=uneven_basis)
232257
else:
233-
file_data = self._get_interpolated_slices(dsets, ni, **kwargs)
258+
file_data = self._get_interpolated_slices(dsets, ni, uneven_basis=uneven_basis)
234259

235260
# Create histograms of data
236261
for field in pdf_list:
237-
hist, bin_vals = np.histogram(file_data[field], bins=bins, range=bounds[field])
262+
hist, bin_vals = np.histogram(file_data[field], bins=bins, range=tuple(bounds[field]))
238263
histograms[field] += hist
239264
bin_edges[field] = bin_vals
240265

@@ -254,19 +279,35 @@ def calculate_pdfs(self, pdf_list, bins=100, threeD=False, bases=['x', 'z'], **k
254279
self._calculate_pdf_statistics()
255280

256281

257-
def plot_pdfs(self, dpi=150, **kwargs):
282+
def plot_pdfs(
283+
self,
284+
dpi: int = 150,
285+
col_inch: float = 3,
286+
row_inch: float = 3,
287+
) -> None:
258288
"""
259289
Plot the probability distribution functions and save them to file.
260290
261291
# Arguments
262-
dpi (int, optional) :
263-
Pixel density of output image.
264-
**kwargs : additional keyword arguments for RegularPlotGrid()
292+
dpi : Pixel density of output image.
293+
col_inch : Width of each column in inches.
294+
row_inch : Height of each row in inches.
265295
"""
266296
with self.my_sync:
267297
if self.comm.rank != 0: return
268298

269-
grid = RegularPlotGrid(num_rows=1,num_cols=1, **kwargs)
299+
grid = RegularPlotGrid(
300+
num_rows=1,
301+
num_cols=1,
302+
cbar=False,
303+
polar=False,
304+
mollweide=False,
305+
orthographic=False,
306+
threeD=False,
307+
col_inch=col_inch,
308+
row_inch=row_inch,
309+
pad_factor=10
310+
)
270311
ax = grid.axes['ax_0-0']
271312

272313
for k, data in self.pdfs.items():
@@ -290,7 +331,7 @@ def plot_pdfs(self, dpi=150, **kwargs):
290331

291332
self._save_pdfs()
292333

293-
def _save_pdfs(self):
334+
def _save_pdfs(self) -> None:
294335
"""
295336
Save PDFs to file. For each PDF, e.g., 'entropy' and 'w', the file will have a dataset with:
296337
xs - the x-values of the PDF
@@ -303,8 +344,8 @@ def _save_pdfs(self):
303344
pdf, xs, dx = data
304345
this_group = f.create_group(k)
305346
for d, n in ((pdf, 'pdf'), (xs, 'xs')):
306-
dset = this_group.create_dataset(name=n, shape=d.shape, dtype=np.float64)
347+
this_group.create_dataset(name=n, shape=d.shape, dtype=np.float64)
307348
f['{:s}/{:s}'.format(k, n)][:] = d
308-
dset = this_group.create_dataset(name='dx', shape=(1,), dtype=np.float64)
349+
this_group.create_dataset(name='dx', shape=(1,), dtype=np.float64)
309350
f['{:s}/dx'.format(k)][0] = dx
310351

0 commit comments

Comments
 (0)