Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run pyupgrade #4083

Merged
merged 3 commits into from
Sep 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@ repos:
name: nbqa-isort
alias: nbqa-isort
additional_dependencies: ['isort']
- repo: https://github.com/asottile/pyupgrade
rev: v2.7.2
hooks:
- id: pyupgrade
args: ['--py36-plus']
1 change: 0 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# pymc3 documentation build configuration file, created by
# sphinx-quickstart on Sat Dec 26 14:40:23 2015.
Expand Down
14 changes: 7 additions & 7 deletions docs/source/sphinxext/gallery_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ def __init__(self, filename, target_dir):
self.basename = os.path.basename(filename)
self.stripped_name = os.path.splitext(self.basename)[0]
self.output_html = os.path.join(
"..", "notebooks", "{}.html".format(self.stripped_name)
"..", "notebooks", f"{self.stripped_name}.html"
)
self.image_dir = os.path.join(target_dir, "_images")
self.png_path = os.path.join(
self.image_dir, "{}.png".format(self.stripped_name)
self.image_dir, f"{self.stripped_name}.png"
)
with open(filename, "r") as fid:
with open(filename) as fid:
self.json_source = json.load(fid)
self.pagetitle = self.extract_title()
self.default_image_loc = DEFAULT_IMG_LOC
Expand All @@ -89,7 +89,7 @@ def __init__(self, filename, target_dir):

self.gen_previews()
else:
print("skipping {0}".format(filename))
print(f"skipping {filename}")

def extract_preview_pic(self):
"""By default, just uses the last image in the notebook."""
Expand Down Expand Up @@ -136,7 +136,7 @@ def build_gallery(srcdir, gallery):
working_dir = os.getcwd()
os.chdir(srcdir)
static_dir = os.path.join(srcdir, "_static")
target_dir = os.path.join(srcdir, "nb_{}".format(gallery))
target_dir = os.path.join(srcdir, f"nb_{gallery}")
image_dir = os.path.join(target_dir, "_images")
source_dir = os.path.abspath(
os.path.join(os.path.dirname(os.path.dirname(srcdir)), "notebooks")
Expand Down Expand Up @@ -182,8 +182,8 @@ def build_gallery(srcdir, gallery):
"thumb": os.path.basename(default_png_path),
}

js_file = os.path.join(image_dir, "gallery_{}_contents.js".format(gallery))
with open(table_of_contents_file, "r") as toc:
js_file = os.path.join(image_dir, f"gallery_{gallery}_contents.js")
with open(table_of_contents_file) as toc:
table_of_contents = toc.read()

js_contents = "Gallery.examples = {}\n{}".format(
Expand Down
8 changes: 4 additions & 4 deletions pymc3/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,8 @@ def __getitem__(self, idx):
return self.get_sampler_stats(var, burn=burn, thin=thin)
raise KeyError("Unknown variable %s" % var)

_attrs = set(['_straces', 'varnames', 'chains', 'stat_names',
'supports_sampler_stats', '_report'])
_attrs = {'_straces', 'varnames', 'chains', 'stat_names',
'supports_sampler_stats', '_report'}

def __getattr__(self, name):
# Avoid infinite recursion when called before __init__
Expand Down Expand Up @@ -417,7 +417,7 @@ def add_values(self, vals, overwrite=False) -> None:
self.varnames.remove(k)
new_var = 0
else:
raise ValueError("Variable name {} already exists.".format(k))
raise ValueError(f"Variable name {k} already exists.")

self.varnames.append(k)

Expand Down Expand Up @@ -448,7 +448,7 @@ def remove_values(self, name):
"""
varnames = self.varnames
if name not in varnames:
raise KeyError("Unknown variable {}".format(name))
raise KeyError(f"Unknown variable {name}")
self.varnames.remove(name)
chains = self._straces
for chain in chains.values():
Expand Down
2 changes: 1 addition & 1 deletion pymc3/backends/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def sampler_vars(self, values):
data.create_dataset(varname, (self.draws,), dtype=dtype, maxshape=(None,))
elif data.keys() != sampler.keys():
raise ValueError(
"Sampler vars can't change, names incompatible: {} != {}".format(data.keys(), sampler.keys()))
f"Sampler vars can't change, names incompatible: {data.keys()} != {sampler.keys()}")
self.records_stats = True

def setup(self, draws, chain, sampler_vars=None):
Expand Down
2 changes: 1 addition & 1 deletion pymc3/backends/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def load(self, model: Model) -> 'NDArray':
raise TraceDirectoryError("%s is not a trace directory" % self.directory)

new_trace = NDArray(model=model)
with open(self.metadata_path, 'r') as buff:
with open(self.metadata_path) as buff:
metadata = json.load(buff)

metadata['_stats'] = [{k: np.array(v) for k, v in stat.items()} for stat in metadata['_stats']]
Expand Down
8 changes: 4 additions & 4 deletions pymc3/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,8 @@ def load(name, model=None):
db.connect()
varnames = _get_table_list(db.cursor)
if len(varnames) == 0:
raise ValueError(('Can not get variable list for database'
'`{}`'.format(name)))
raise ValueError('Can not get variable list for database'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Heh.

'`{}`'.format(name))
chains = _get_chain_list(db.cursor, varnames[0])

straces = []
Expand All @@ -367,14 +367,14 @@ def _get_table_list(cursor):


def _get_var_strs(cursor, varname):
cursor.execute('SELECT * FROM [{}]'.format(varname))
cursor.execute(f'SELECT * FROM [{varname}]')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

boy howdy we should get rid of this backend

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes...

col_names = (col_descr[0] for col_descr in cursor.description)
return [name for name in col_names if name.startswith('v')]


def _get_chain_list(cursor, varname):
"""Return a list of sorted chains for `varname`."""
cursor.execute('SELECT DISTINCT chain FROM [{}]'.format(varname))
cursor.execute(f'SELECT DISTINCT chain FROM [{varname}]')
chains = sorted([chain[0] for chain in cursor.fetchall()])
return chains

Expand Down
6 changes: 3 additions & 3 deletions pymc3/backends/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def setup(self, draws, chain):
self._fh.close()

self.chain = chain
self.filename = os.path.join(self.name, 'chain-{}.csv'.format(chain))
self.filename = os.path.join(self.name, f'chain-{chain}.csv')

cnames = [fv for v in self.varnames for fv in self.flat_names[v]]

Expand Down Expand Up @@ -201,7 +201,7 @@ def load(name, model=None):
files = glob(os.path.join(name, 'chain-*.csv'))

if len(files) == 0:
raise ValueError('No files present in directory {}'.format(name))
raise ValueError(f'No files present in directory {name}')

straces = []
for f in files:
Expand Down Expand Up @@ -249,7 +249,7 @@ def dump(name, trace, chains=None):
chains = trace.chains

for chain in chains:
filename = os.path.join(name, 'chain-{}.csv'.format(chain))
filename = os.path.join(name, f'chain-{chain}.csv')
df = ttab.trace_to_dataframe(
trace, chains=chain, include_transformed=True)
df.to_csv(filename, index=False)
6 changes: 3 additions & 3 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def assert_negative_support(var, label, distname, value=-1e-6):
support = False

if np.any(support):
msg = "The variable specified for {0} has negative support for {1}, ".format(
msg = "The variable specified for {} has negative support for {}, ".format(
label, distname
)
msg += "likely making it unsuitable for this parameter."
Expand Down Expand Up @@ -294,7 +294,7 @@ def logcdf(self, value):
tt.switch(
tt.eq(value, self.upper),
0,
tt.log((value - self.lower)) - tt.log((self.upper - self.lower)),
tt.log(value - self.lower) - tt.log(self.upper - self.lower),
),
)

Expand Down Expand Up @@ -1887,7 +1887,7 @@ class StudentT(Continuous):

def __init__(self, nu, mu=0, lam=None, sigma=None, sd=None, *args, **kwargs):
super().__init__(*args, **kwargs)
super(StudentT, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
if sd is not None:
sigma = sd
warnings.warn("sd is deprecated, use sigma instead", DeprecationWarning)
Expand Down
10 changes: 5 additions & 5 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __new__(cls, name, *args, **kwargs):
"for a standalone distribution.")

if not isinstance(name, string_types):
raise TypeError("Name needs to be a string but got: {}".format(name))
raise TypeError(f"Name needs to be a string but got: {name}")

data = kwargs.pop('observed', None)
cls.data = data
Expand Down Expand Up @@ -728,7 +728,7 @@ def draw_values(params, point=None, size=None):
# test_distributions_random::TestDrawValues::test_draw_order fails without it
# The remaining params that must be drawn are all hashable
to_eval = set()
missing_inputs = set([j for j, p in symbolic_params])
missing_inputs = {j for j, p in symbolic_params}
while to_eval or missing_inputs:
if to_eval == missing_inputs:
raise ValueError('Cannot resolve inputs for {}'.format([get_var_name(params[j]) for j in to_eval]))
Expand Down Expand Up @@ -828,7 +828,7 @@ def vectorize_theano_function(f, inputs, output):
"""
inputs_signatures = ",".join(
[
get_vectorize_signature(var, var_name="i_{}".format(input_ind))
get_vectorize_signature(var, var_name=f"i_{input_ind}")
for input_ind, var in enumerate(inputs)
]
)
Expand All @@ -846,9 +846,9 @@ def get_vectorize_signature(var, var_name="i"):
return "()"
else:
sig = ",".join(
["{}_{}".format(var_name, axis_ind) for axis_ind in range(var.ndim)]
[f"{var_name}_{axis_ind}" for axis_ind in range(var.ndim)]
)
return "({})".format(sig)
return f"({sig})"


def _draw_value(param, point=None, givens=None, size=None):
Expand Down
2 changes: 1 addition & 1 deletion pymc3/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(self, w, comp_dists, *args, **kwargs):
isinstance(comp_dists, Distribution)
or (
isinstance(comp_dists, Iterable)
and all((isinstance(c, Distribution) for c in comp_dists))
and all(isinstance(c, Distribution) for c in comp_dists)
)
):
raise TypeError(
Expand Down
10 changes: 5 additions & 5 deletions pymc3/distributions/posterior_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
assert point_list is None and dict is None
self.data = {} # Dict[str, np.ndarray]
self._len = sum(
(len(multi_trace._straces[chain]) for chain in multi_trace.chains)
len(multi_trace._straces[chain]) for chain in multi_trace.chains
)
self.varnames = multi_trace.varnames
for vn in multi_trace.varnames:
Expand Down Expand Up @@ -153,15 +153,15 @@ def __getitem__(self, item: Union[slice, int]) -> "_TraceDict":

def __getitem__(self, item):
if isinstance(item, str):
return super(_TraceDict, self).__getitem__(item)
return super().__getitem__(item)
elif isinstance(item, slice):
return self._extract_slice(item)
elif isinstance(item, int):
return _TraceDict(
dict={k: np.atleast_1d(v[item]) for k, v in self.data.items()}
)
elif hasattr(item, "name"):
return super(_TraceDict, self).__getitem__(item.name)
return super().__getitem__(item.name)
else:
raise IndexError("Illegal index %s for _TraceDict" % str(item))

Expand Down Expand Up @@ -242,7 +242,7 @@ def fast_sample_posterior_predictive(
"Should not specify both keep_size and samples arguments"
)

if isinstance(trace, list) and all((isinstance(x, dict) for x in trace)):
if isinstance(trace, list) and all(isinstance(x, dict) for x in trace):
_trace = _TraceDict(point_list=trace)
elif isinstance(trace, MultiTrace):
_trace = _TraceDict(multi_trace=trace)
Expand Down Expand Up @@ -454,7 +454,7 @@ def draw_values(self) -> List[np.ndarray]:
# test_distributions_random::TestDrawValues::test_draw_order fails without it
# The remaining params that must be drawn are all hashable
to_eval: Set[int] = set()
missing_inputs: Set[int] = set([j for j, p in self.symbolic_params])
missing_inputs: Set[int] = {j for j, p in self.symbolic_params}

while to_eval or missing_inputs:
if to_eval == missing_inputs:
Expand Down
12 changes: 6 additions & 6 deletions pymc3/distributions/shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ def _check_shape_type(shape):
shape = np.atleast_1d(shape)
for s in shape:
if isinstance(s, np.ndarray) and s.ndim > 0:
raise TypeError("Value {} is not a valid integer".format(s))
raise TypeError(f"Value {s} is not a valid integer")
o = int(s)
if o != s:
raise TypeError("Value {} is not a valid integer".format(s))
raise TypeError(f"Value {s} is not a valid integer")
out.append(o)
except Exception:
raise TypeError(
"Supplied value {} does not represent a valid shape".format(shape)
f"Supplied value {shape} does not represent a valid shape"
)
return tuple(out)

Expand Down Expand Up @@ -103,7 +103,7 @@ def shapes_broadcasting(*args, raise_exception=False):
if raise_exception:
raise ValueError(
"Supplied shapes {} do not broadcast together".format(
", ".join(["{}".format(a) for a in args])
", ".join([f"{a}" for a in args])
)
)
else:
Expand Down Expand Up @@ -165,7 +165,7 @@ def broadcast_dist_samples_shape(shapes, size=None):
if broadcasted_shape is None:
raise ValueError(
"Cannot broadcast provided shapes {} given size: {}".format(
", ".join(["{}".format(s) for s in shapes]), size
", ".join([f"{s}" for s in shapes]), size
)
)
return broadcasted_shape
Expand All @@ -181,7 +181,7 @@ def broadcast_dist_samples_shape(shapes, size=None):
except ValueError:
raise ValueError(
"Cannot broadcast provided shapes {} given size: {}".format(
", ".join(["{}".format(s) for s in shapes]), size
", ".join([f"{s}" for s in shapes]), size
)
)
broadcastable_shapes = []
Expand Down
4 changes: 2 additions & 2 deletions pymc3/distributions/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
epsilon=1,
**kwargs,
):
"""
r"""
This class stores a function defined by the user in Python language.

function: function
Expand Down Expand Up @@ -125,7 +125,7 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
distance = self.distance.__name__

if formatting == "latex":
return f"$\\text{{{name}}} \sim \\text{{Simulator}}(\\text{{{function}}}({params}), \\text{{{distance}}}, \\text{{{sum_stat}}})$"
return f"$\\text{{{name}}} \\sim \\text{{Simulator}}(\\text{{{function}}}({params}), \\text{{{distance}}}, \\text{{{sum_stat}}})$"
else:
return f"{name} ~ Simulator({function}({params}), {distance}, {sum_stat})"

Expand Down
2 changes: 1 addition & 1 deletion pymc3/examples/samplers_mvnormal.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def run(steppers, p):
print('{} samples across {} chains'.format(len(mt) * mt.nchains, mt.nchains))
traces[name] = mt
en = pm.ess(mt)
print('effective: {}\r\n'.format(en))
print(f'effective: {en}\r\n')
if USE_XY:
effn[name] = np.mean(en['x']) / len(mt) / mt.nchains
else:
Expand Down
12 changes: 6 additions & 6 deletions pymc3/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ class ShapeError(Exception):
"""Error that the shape of a variable is incorrect."""
def __init__(self, message, actual=None, expected=None):
if actual is not None and expected is not None:
super().__init__('{} (actual {} != expected {})'.format(message, actual, expected))
super().__init__(f'{message} (actual {actual} != expected {expected})')
elif actual is not None and expected is None:
super().__init__('{} (actual {})'.format(message, actual))
super().__init__(f'{message} (actual {actual})')
elif actual is None and expected is not None:
super().__init__('{} (expected {})'.format(message, expected))
super().__init__(f'{message} (expected {expected})')
else:
super().__init__(message)

Expand All @@ -58,10 +58,10 @@ class DtypeError(TypeError):
"""Error that the dtype of a variable is incorrect."""
def __init__(self, message, actual=None, expected=None):
if actual is not None and expected is not None:
super().__init__('{} (actual {} != expected {})'.format(message, actual, expected))
super().__init__(f'{message} (actual {actual} != expected {expected})')
elif actual is not None and expected is None:
super().__init__('{} (actual {})'.format(message, actual))
super().__init__(f'{message} (actual {actual})')
elif actual is None and expected is not None:
super().__init__('{} (expected {})'.format(message, expected))
super().__init__(f'{message} (expected {expected})')
else:
super().__init__(message)
Loading