Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixup mypy errors in sampling.py #4327

Merged
merged 2 commits into from
Dec 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 31 additions & 41 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,16 @@

"""Functions for MCMC sampling."""

import collections.abc as abc
import logging
import pickle
import sys
import time
import warnings

from collections import defaultdict
from collections.abc import Iterable
from copy import copy
from typing import Any, Dict
from typing import Iterable as TIterable
from typing import List, Optional, Union, cast
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast

import arviz
import numpy as np
Expand Down Expand Up @@ -57,8 +55,8 @@
HamiltonianMC,
Metropolis,
Slice,
arraystep,
)
from pymc3.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
from pymc3.step_methods.hmc import quadpotential
from pymc3.util import (
chains_and_samples,
Expand Down Expand Up @@ -93,15 +91,19 @@
CategoricalGibbsMetropolis,
PGBART,
)
Step = Union[BlockedStep, CompoundStep]

ArrayLike = Union[np.ndarray, List[float]]
PointType = Dict[str, np.ndarray]
PointList = List[PointType]
Backend = Union[BaseTrace, MultiTrace, NDArray]

_log = logging.getLogger("pymc3")


def instantiate_steppers(_model, steps, selected_steps, step_kwargs=None):
def instantiate_steppers(
_model, steps: List[Step], selected_steps, step_kwargs=None
) -> Union[Step, List[Step]]:
"""Instantiate steppers assigned to the model variables.

This function is intended to be called automatically from ``sample()``, but
Expand Down Expand Up @@ -142,7 +144,7 @@ def instantiate_steppers(_model, steps, selected_steps, step_kwargs=None):
raise ValueError("Unused step method arguments: %s" % unused_args)

if len(steps) == 1:
steps = steps[0]
return steps[0]
michaelosthege marked this conversation as resolved.
Show resolved Hide resolved

return steps

Expand Down Expand Up @@ -216,7 +218,7 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None
return instantiate_steppers(model, steps, selected_steps, step_kwargs)


def _print_step_hierarchy(s, level=0):
def _print_step_hierarchy(s: Step, level=0) -> None:
if isinstance(s, CompoundStep):
_log.info(">" * level + "CompoundStep")
for i in s.methods:
Expand Down Expand Up @@ -447,7 +449,7 @@ def sample(
if random_seed is not None:
np.random.seed(random_seed)
random_seed = [np.random.randint(2 ** 30) for _ in range(chains)]
if not isinstance(random_seed, Iterable):
if not isinstance(random_seed, abc.Iterable):
michaelosthege marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")

if not discard_tuned_samples and not return_inferencedata:
Expand Down Expand Up @@ -542,7 +544,7 @@ def sample(

has_population_samplers = np.any(
[
isinstance(m, arraystep.PopulationArrayStepShared)
isinstance(m, PopulationArrayStepShared)
for m in (step.methods if isinstance(step, CompoundStep) else [step])
]
)
Expand Down Expand Up @@ -706,7 +708,7 @@ def _sample_many(
trace: MultiTrace
Contains samples of all chains
"""
traces = []
traces: List[Backend] = []
for i in range(chains):
trace = _sample(
draws=draws,
Expand Down Expand Up @@ -1140,7 +1142,7 @@ def _run_secondary(c, stepper_dumps, secondary_end):
# has to be updated, therefore we identify the substeppers first.
population_steppers = []
for sm in stepper.methods if isinstance(stepper, CompoundStep) else [stepper]:
if isinstance(sm, arraystep.PopulationArrayStepShared):
if isinstance(sm, PopulationArrayStepShared):
population_steppers.append(sm)
while True:
incoming = secondary_end.recv()
Expand Down Expand Up @@ -1259,7 +1261,7 @@ def _prepare_iter_population(
population = [Point(start[c], model=model) for c in range(nchains)]

# 3. Set up the steppers
steppers = [None] * nchains
steppers: List[Step] = []
for c in range(nchains):
# need indepenent samplers for each chain
# it is important to copy the actual steppers (but not the delta_logp)
Expand All @@ -1269,9 +1271,9 @@ def _prepare_iter_population(
chainstep = copy(step)
# link population samplers to the shared population state
for sm in chainstep.methods if isinstance(step, CompoundStep) else [chainstep]:
if isinstance(sm, arraystep.PopulationArrayStepShared):
if isinstance(sm, PopulationArrayStepShared):
sm.link_population(population, c)
steppers[c] = chainstep
steppers.append(chainstep)

# 4. configure tracking of sampler stats
for c in range(nchains):
Expand Down Expand Up @@ -1349,7 +1351,7 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
steppers[c].report._finalize(strace)


def _choose_backend(trace, chain, **kwds):
def _choose_backend(trace, chain, **kwds) -> Backend:
"""Selects or creates a NDArray trace backend for a particular chain.

Parameters
Expand Down Expand Up @@ -1562,8 +1564,8 @@ class _DefaultTrace:
`insert()` method
"""

trace_dict = {} # type: Dict[str, np.ndarray]
_len = None # type: int
trace_dict: Dict[str, np.ndarray] = {}
_len: Optional[int] = None

def __init__(self, samples: int):
self._len = samples
Expand Down Expand Up @@ -1600,7 +1602,7 @@ def sample_posterior_predictive(
trace,
samples: Optional[int] = None,
model: Optional[Model] = None,
vars: Optional[TIterable[Tensor]] = None,
vars: Optional[Iterable[Tensor]] = None,
var_names: Optional[List[str]] = None,
size: Optional[int] = None,
keep_size: Optional[bool] = False,
Expand Down Expand Up @@ -1885,8 +1887,7 @@ def sample_posterior_predictive_w(
def sample_prior_predictive(
samples=500,
model: Optional[Model] = None,
vars: Optional[TIterable[str]] = None,
var_names: Optional[TIterable[str]] = None,
var_names: Optional[Iterable[str]] = None,
random_seed=None,
) -> Dict[str, np.ndarray]:
"""Generate samples from the prior predictive distribution.
Expand All @@ -1896,9 +1897,6 @@ def sample_prior_predictive(
samples : int
Number of samples from the prior predictive to generate. Defaults to 500.
model : Model (optional if in ``with`` context)
vars : Iterable[str]
A list of names of variables for which to compute the posterior predictive
samples. *DEPRECATED* - Use ``var_names`` argument instead.
var_names : Iterable[str]
A list of names of variables for which to compute the posterior predictive
samples. Defaults to both observed and unobserved RVs.
Expand All @@ -1913,22 +1911,14 @@ def sample_prior_predictive(
"""
model = modelcontext(model)

if vars is None and var_names is None:
if var_names is None:
prior_pred_vars = model.observed_RVs
prior_vars = (
get_default_varnames(model.unobserved_RVs, include_transformed=True) + model.potentials
)
vars_ = [var.name for var in prior_vars + prior_pred_vars]
vars = set(vars_)
elif vars is None:
vars = var_names
vars_ = vars
elif vars is not None:
warnings.warn("vars argument is deprecated in favor of var_names.", DeprecationWarning)
vars_ = vars
vars_: Set[str] = {var.name for var in prior_vars + prior_pred_vars}
else:
raise ValueError("Cannot supply both vars and var_names arguments.")
vars = cast(TIterable[str], vars) # tell mypy that vars cannot be None here.
vars_ = set(var_names)

if random_seed is not None:
np.random.seed(random_seed)
Expand All @@ -1940,8 +1930,8 @@ def sample_prior_predictive(
if data is None:
raise AssertionError("No variables sampled: attempting to sample %s" % names)

prior = {} # type: Dict[str, np.ndarray]
for var_name in vars:
prior: Dict[str, np.ndarray] = {}
for var_name in vars_:
michaelosthege marked this conversation as resolved.
Show resolved Hide resolved
if var_name in data:
prior[var_name] = data[var_name]
elif is_transformed_name(var_name):
Expand Down Expand Up @@ -2093,15 +2083,15 @@ def init_nuts(
var = np.ones_like(mean)
potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10)
elif init == "advi+adapt_diag_grad":
approx = pm.fit(
approx: pm.MeanField = pm.fit(
random_seed=random_seed,
n=n_init,
method="advi",
model=model,
callbacks=cb,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
) # type: pm.MeanField
)
start = approx.sample(draws=chains)
start = list(start)
stds = approx.bij.rmap(approx.std.eval())
Expand All @@ -2119,7 +2109,7 @@ def init_nuts(
callbacks=cb,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
) # type: pm.MeanField
)
start = approx.sample(draws=chains)
start = list(start)
stds = approx.bij.rmap(approx.std.eval())
Expand All @@ -2137,7 +2127,7 @@ def init_nuts(
callbacks=cb,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
) # type: pm.MeanField
)
start = approx.sample(draws=chains)
start = list(start)
stds = approx.bij.rmap(approx.std.eval())
Expand Down
5 changes: 4 additions & 1 deletion pymc3/step_methods/arraystep.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.

from enum import IntEnum, unique
from typing import Dict, List

import numpy as np

from numpy.random import uniform

from pymc3.blocking import ArrayOrdering, DictToArrayBijection
from pymc3.model import modelcontext
from pymc3.model import PyMC3Variable, modelcontext
from pymc3.step_methods.compound import CompoundStep
from pymc3.theanof import inputvars
from pymc3.util import get_var_name
Expand All @@ -46,6 +47,8 @@ class Competence(IntEnum):
class BlockedStep:

generates_stats = False
stats_dtypes: List[Dict[str, np.dtype]] = []
michaelosthege marked this conversation as resolved.
Show resolved Hide resolved
vars: List[PyMC3Variable] = []

def __new__(cls, *args, **kwargs):
blocked = kwargs.get("blocked")
Expand Down
4 changes: 0 additions & 4 deletions pymc3/step_methods/mlda.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,10 +356,6 @@ class MLDA(ArrayStepShared):
default_blocked = True
generates_stats = True

# stat data types are different, depending on the base sampler.
# these are assigned in the init method.
stats_dtypes = None

michaelosthege marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
coarse_models: List[Model],
Expand Down
5 changes: 2 additions & 3 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,9 +903,8 @@ def test_respects_shape(self):
with pm.Model():
mu = pm.Gamma("mu", 3, 1, shape=1)
goals = pm.Poisson("goals", mu, shape=shape)
with pytest.warns(DeprecationWarning):
trace1 = pm.sample_prior_predictive(10, vars=["mu", "goals"])
trace2 = pm.sample_prior_predictive(10, var_names=["mu", "goals"])
trace1 = pm.sample_prior_predictive(10, var_names=["mu", "mu", "goals"])
trace2 = pm.sample_prior_predictive(10, var_names=["mu", "goals"])
if shape == 2: # want to test shape as an int
shape = (2,)
assert trace1["goals"].shape == (10,) + shape
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ convention = numpy
[isort]
lines_between_types = 1
profile = black

[mypy]
ignore_missing_imports = True