
Merge statespace module from https://github.com/jessegrabowski/pymc_statespace #174

Merged · 140 commits · Aug 18, 2023

Conversation

@jessegrabowski (Member) commented May 29, 2023

This is the promised statespace PR. It's a WIP for several reasons, but it already represents an MVP that can be used. I'll try to explain what I was going for here. I expect that a lot will need to change, though, and I look forward to working hard to improve it!

Overall design goal

Statespace models are a pair of linear equations of the form:

$$\begin{align}x_{t+1} &= T_tx_t + c_t + R_t\epsilon_t, &\quad \epsilon_t &\sim N(0, Q_t) \\ y_t &= Z_tx_t + d_t + \eta_t &\quad \eta_t &\sim N(0, H_t) \end{align}$$

where the first equation is called the "transition equation" and the second the "observation equation". The state vector $x_t$ contains both observed and unobserved states, while the observation vector $y_t$ contains only the states observed in data.

Since everything is linear and Gaussian, the posterior distribution over the observed states is just a multivariate normal. In fact, it can be shown that this is a special case of a Gaussian process. But since we might have many states or many time steps, we can compute the posterior mean and covariance more efficiently via recursion. This is the Kalman filter.
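As a sketch of what one step of that recursion looks like (a minimal numpy illustration of the textbook predict/update equations, not the implementation in this PR -- the function name and signature are mine):

```python
import numpy as np

def kalman_step(x, P, y, T, Z, R, Q, H, c, d):
    """One predict/update step of the standard Kalman filter.

    x, P : prior state mean and covariance
    y    : the observation at time t
    The remaining arguments are the statespace matrices from the
    transition and observation equations above.
    """
    # Predict: push the state through the transition equation
    x_pred = T @ x + c
    P_pred = T @ P @ T.T + R @ Q @ R.T

    # Update: condition on the observation y_t
    y_hat = Z @ x_pred + d                # predicted observation
    v = y - y_hat                         # innovation
    F = Z @ P_pred @ Z.T + H              # innovation covariance
    K = P_pred @ Z.T @ np.linalg.inv(F)   # Kalman gain

    x_filt = x_pred + K @ v
    P_filt = (np.eye(len(x)) - K @ Z) @ P_pred
    return x_filt, P_filt
```

Iterating this over all timesteps yields the filtered means and covariances, from which the logp of the data falls out via the innovations `v` and their covariances `F`.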

Sidebar 1: Why not just scan?

Given the advancements in scan lately, it's fair to ask why we bother with filtering at all, rather than just letting PyMC automatically infer the logp of these two equations. This might be possible, and it's something I'd like to explore. It would also let us push out beyond Gaussian errors, which would be great. There was already a user on the forum looking for Poisson-distributed observations, for example.

For now it's not possible, because $Q$ might not be full rank -- that is, not all states in $x_t$ need to be stochastic. We could solve this with clever slicing, though. A second problem is that the quantity $R_t\epsilon_t$ is not measurable. For Gaussian errors this isn't a problem, because we can just fold it into the covariance matrix and write $x_{t+1} = T_tx_t + \epsilon_t, \quad \epsilon_t \sim N(0, R_tQ_tR_t^T)$, but this identity doesn't hold in general.

Sidebar over

Basically, this PR has the following goals:

  1. Abstract away the matrices and allow users to focus on model parameters.
  2. Offer computationally efficient ways to compute the logp of these models.
  3. Allow the models to be sampled quickly.
  4. Alleviate headaches associated with post-estimation tasks, especially forecasting.

Modules in this PR

I will briefly introduce what I've done in this PR, and try to justify my choices. I am confident they can be improved.

Core

The core module is responsible for representing arbitrary statespace models. There are two files, representation.py and statespace.py.

representation.py

This is the lowest-level object in the module. It is responsible for initializing and storing the matrices c, d, T, Z, R, H, Q, along with initial state and covariance x0 and P0. The module overloads __getitem__ and __setitem__ to allow the user to slice into matrices by name, for example mod['transition', 0, 0] gets the [0, 0] position of the T matrix.
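To illustrate the name-based indexing idea, here is a stripped-down sketch (not the actual class -- the real one also stores c, d, x0, P0 and all the time-varying machinery described below):

```python
import numpy as np

# Simplified sketch of slicing into statespace matrices by name via
# __getitem__/__setitem__ overloads; dimensions and names are illustrative.
class Representation:
    def __init__(self, k_states, k_obs):
        self.matrices = {
            "transition": np.zeros((k_states, k_states)),  # T
            "design": np.zeros((k_obs, k_states)),         # Z
            "selection": np.zeros((k_states, k_states)),   # R
            "obs_cov": np.zeros((k_obs, k_obs)),           # H
            "state_cov": np.zeros((k_states, k_states)),   # Q
        }

    def __getitem__(self, key):
        name, *idx = key if isinstance(key, tuple) else (key,)
        mat = self.matrices[name]
        return mat[tuple(idx)] if idx else mat

    def __setitem__(self, key, value):
        name, *idx = key if isinstance(key, tuple) else (key,)
        if idx:
            self.matrices[name][tuple(idx)] = value
        else:
            self.matrices[name] = value

mod = Representation(k_states=2, k_obs=1)
mod["transition", 0, 0] = 0.9   # set T[0, 0] by name
```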

In addition, it has a bunch of checks and logic to handle time-varying matrices. If a user wants the statespace matrices to vary over time, it is necessary to duplicate them and store the whole stack to scan over later. If the model is not time-varying, it automatically slices around the time dimension so the user is never confronted with it.

Users should never have to touch this; it's all low-level machinery.

statespace.py

This is the base class for all statespace models. It's responsible for combining the statespace matrices with a Kalman filter to make a logp graph. This is accomplished by the gather_required_random_variables, update, and build_statespace_graph methods. Only update varies between models and needs to be implemented; it is responsible for taking a flat vector of parameters and shuttling them to the correct places in the statespace matrices.

The property param_names also needs to be set for each model. This defines the names of the parameters that gather_required_random_variables will look for in the pymc model. It also defines the order of the flat parameter vector that will be passed into update.
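As a hypothetical illustration of the contract between param_names and update (the class and attribute names here are made up for the example, not the module's actual API):

```python
import numpy as np

# Hypothetical sketch: param_names fixes the order of the flat vector
# theta, and update() shuttles its entries into the statespace matrices.
class LocalLevelSketch:
    param_names = ("sigma_obs", "sigma_level")  # order of theta

    def __init__(self):
        self.H = np.zeros((1, 1))  # observation noise covariance
        self.Q = np.zeros((1, 1))  # state noise covariance

    def update(self, theta):
        # theta is flat, ordered according to param_names
        self.H[0, 0] = theta[0] ** 2
        self.Q[0, 0] = theta[1] ** 2

mod = LocalLevelSketch()
mod.update(np.array([0.5, 0.1]))  # H[0,0] = 0.25, Q[0,0] = 0.01
```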

Filters

The filters module holds all the Kalman filter implementations.

distributions.py

Work in progress. Eventually, this should implement a PyMC distribution wrapper around the Kalman filter, so that we can sample from it directly. There are just a lot of wrinkles to iron out.

kalman_filter.py

The actual Kalman filters. Currently there are 5 implemented: standard, univariate, single time series, Cholesky, and steady state. These need to be benchmarked against each other. Single time series should be used when there is only a single observed state (e.g. ARIMA); otherwise use standard. Cholesky is supposed to be faster, but in my limited testing it's not; I'm not sure why. In principle the univariate filter is the most robust, but it has a scan inside a scan, so it's quite slow. I haven't benchmarked it in JAX, though.

kalman_smoother.py

Pytensor implementation of the Kalman smoother. Good for hidden state inference, but split out because not all users will need it; it's a post-estimation thing.

numpy_filter.py

This is a re-implementation of the Cholesky Kalman filter in pure numpy, potentially for use in distributions.py. I'm not sure what pm.draw should return when called on a statespace distribution, so it's there as an option.

utilities.py

Shared functions between modules. Currently this just holds a helper function to sort scan inputs into sequences and non-sequences. This is needed because time-varying matrices are sequences, while static matrices are non-sequences. I guess I could always make everything a sequence and just copy the matrix a bunch of times when it's not time-varying, but this is more memory efficient (cope because I already sunk a lot of time into doing it this way?).
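The sorting logic can be sketched like this (an illustrative stand-in, not the module's actual helper; it assumes static matrices are 2D and time-varying stacks are 3D with a leading time dimension):

```python
import numpy as np

def split_scan_inputs(matrices, n_timesteps):
    """Illustrative helper: matrices with a leading time dimension of
    length n_timesteps are scan sequences; static 2D matrices are
    non-sequences."""
    sequences, non_sequences = [], []
    for M in matrices:
        if M.ndim == 3 and M.shape[0] == n_timesteps:
            sequences.append(M)   # one slice per timestep
        else:
            non_sequences.append(M)  # shared across all timesteps
    return sequences, non_sequences

T_static = np.zeros((2, 2))    # same transition matrix every step
Z_tv = np.zeros((10, 1, 2))    # a different design matrix per timestep
seqs, non_seqs = split_scan_inputs([T_static, Z_tv], n_timesteps=10)
```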

Models

This module will hold actual implementations of statespace models that users can call. Right now they are fully contained models, following the setup of statsmodels, but I could imagine a better, more modular API. Right now I have VARMA, ARIMA, and local level.

There's also a utility file for shared functions; right now it holds a little function used in update to keep track of slicing up the flat parameter vector.

Utils

A hodge-podge of stuff.

numba_linalg.py

This holds numba implementations of linear algebra routines that don't have existing numba overloads. Currently it's just scipy.linalg.block_diag.

pytensor_scipy.py

This holds a pytensor Op for solving the discrete algebraic Riccati equation (DARE), which should be split off and pulled into pytensor.tensor.slinalg. It is currently used in the SteadyStateFilter, and can save a lot of time by pre-computing and re-using a single matrix inverse for all time steps in the Kalman filter. It doesn't have a jaxified version, though, so it's not actually that useful right now. Solving AREs is useful in general, though.
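For intuition, here is what the steady-state idea looks like in plain scipy (using scipy.linalg.solve_discrete_are, which the Op mirrors; the mapping of the filtering Riccati equation onto scipy's control-form DARE via transposes is standard, but the variable names here are mine):

```python
import numpy as np
from scipy.linalg import solve_discrete_are

# Steady-state covariance of a local level model via the DARE.
# The filtering Riccati equation
#   P = T P T' - T P Z' (Z P Z' + H)^{-1} Z P T' + R Q R'
# maps onto scipy's control-form DARE with a = T.T, b = Z.T,
# q = R Q R', r = H.
T = np.array([[1.0]])   # transition
Z = np.array([[1.0]])   # design
R = np.array([[1.0]])   # selection
Q = np.array([[0.1]])   # state covariance
H = np.array([[0.5]])   # observation covariance

P_steady = solve_discrete_are(T.T, Z.T, R @ Q @ R.T, H)

# Once P has converged, the Kalman gain is constant and can be
# precomputed once, instead of inverting F at every timestep:
F = Z @ P_steady @ Z.T + H
K = P_steady @ Z.T @ np.linalg.inv(F)
```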

simulation.py

This holds numbafied routines for posterior predictive simulation. Since the current implementation returns a pm.Potential for the logp, it's not possible to use the usual posterior predictive sampling machinery, so I resorted to this. I hope it can be removed in the future. It contains separate functions for conditional and unconditional simulation.

Unconditional simulation just applies the observation and state transition equations to an initial state. This is useful for computing theoretical moments of the system, and also for forecasting.

Conditional simulation draws statespace matrices and runs the data through the kalman filter (and, optionally, smoother). This is useful for hidden state inference and missing data interpolation.
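The unconditional case is just the two statespace equations iterated forward. A minimal numpy sketch (assumed names and signature, not the numbafied routines in the module):

```python
import numpy as np

def simulate_unconditional(x0, T, Z, R, Q, H, c, d, n_steps, rng=None):
    """Iterate the transition and observation equations forward from x0,
    drawing fresh state and observation noise at each step."""
    rng = np.random.default_rng(rng)
    k_states, k_obs = T.shape[0], Z.shape[0]
    xs = np.empty((n_steps, k_states))
    ys = np.empty((n_steps, k_obs))
    x = x0
    for t in range(n_steps):
        eps = rng.multivariate_normal(np.zeros(Q.shape[0]), Q)
        eta = rng.multivariate_normal(np.zeros(k_obs), H)
        x = T @ x + c + R @ eps     # transition equation
        ys[t] = Z @ x + d + eta     # observation equation
        xs[t] = x
    return xs, ys
```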

Summary and To Do

Basically that's it. There's still a lot to do, but this is at least a start. I hope it can be useful to the community, and we can get it to be a super fast, reliable alternative to the statsmodels statespace module.

Here's a quick, non-exhaustive list of to-dos:

  1. Get rid of the pm.Potential term via a distribution wrapper around the Kalman filter.
  2. Write numba overloads for relevant linear algebra operations so we can use nutpie.
  3. Consider a more modular API. Users should, in principle, be able to combine statespace models via block-diagonal concatenation. Then we could have a "seasonal" model, "ARIMA" model, local level model, deterministic trend model, etc., and users could combine them as mod = seasonal_part + trend_part + arima_part.
  4. Add better support for forecasting.
  5. Add support for computing theoretical moments.
  6. Add support for computing and plotting IRFs.
  7. Add parameter transformations relevant to time series models.
  8. Plotting functions?
  9. Support for time-series model comparison?
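To make item 3 concrete, the block-diagonal combination of components could look something like this (a hypothetical illustration of the idea, not an API in this PR):

```python
import numpy as np
from scipy.linalg import block_diag

# Hypothetical sketch of combining components by block-diagonal
# concatenation: each component's states evolve independently in the
# transition, while the design matrices are stacked horizontally so the
# observation is the sum of the components' contributions.
T_trend = np.array([[1.0, 1.0],
                    [0.0, 1.0]])        # local linear trend transition
T_ar = np.array([[0.8]])                # AR(1) transition

T_combined = block_diag(T_trend, T_ar)  # (3, 3) combined transition

Z_trend = np.array([[1.0, 0.0]])        # observe the level
Z_ar = np.array([[1.0]])                # observe the AR state
Z_combined = np.hstack([Z_trend, Z_ar]) # (1, 3) combined design
```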



@ricardoV94 ricardoV94 left a comment


Couple of small comments

pymc_experimental/statespace/core/statespace.py (outdated, resolved)
pymc_experimental/statespace/core/statespace.py (outdated, resolved)
pymc_experimental/statespace/core/statespace.py (outdated, resolved)
pymc_experimental/statespace/filters/distributions.py (outdated, resolved)
pymc_experimental/statespace/utils/numba_linalg.py (outdated, resolved)
pymc_experimental/tests/statespace/test_VARMAX.py (outdated, resolved)
@ricardoV94 (Member)

Also need to add relevant entries in the docs

@jessegrabowski (Member, Author)

Big refactor done; maybe have a look at the "Making a custom statespace" notebook to see how things work now. It's way better, I think. First, it's the right way to do it in pytensor -- previously I was just copying a numpy implementation. Second, it gets rid of all the numpy internals in the structural module. Third, it adds unlimited flexibility in how parameters can be manipulated before/after going into the matrices. This is relevant if you want to estimate the length of a frequency period, for example.

I think the tests will pass on this next run (the irony of a commit called "all tests pass" failing CI is not lost on me; pride goeth before the fall), and then I would ask for this to be approved for merging. Then I'll open PRs for the notebooks to go into pymc-examples, and for SolveDiscreteARE to go to pytensor. I'll also open an issue with a checklist of "to do" things still outstanding that I can work on (and maybe even drum up some contributors for).

@ricardoV94 (Member)

Tests should pass once https://github.com/pymc-devs/pymc/releases/tag/v5.7.2 gets picked up by the CI


@ricardoV94 ricardoV94 left a comment


The integration with PyMC models looks sweet.

As is usual for big PRs I left a comment about an unimportant part of the code!

pymc_experimental/statespace/models/utilities.py (outdated, resolved)
@ricardoV94 (Member)

I suggest you run codecov locally (we don't seem to have it in the CI here :( ) just to see if there are some important lines that aren't being tested.


@ricardoV94 ricardoV94 left a comment


Let me know if there's some part that you want me to inspect more carefully

@jessegrabowski (Member, Author)

> Let me know if there's some part that you want me to inspect more carefully

kalman_filter.py is the beating heart; everything else is fancy bits and bobs. If you can spot any optimizations there, it would offer a lot of improvement across the module. It's the most mathy, but there's also a lot of dumb logic about sorting variables by shape, adding checks, etc. Maybe you can spot some pointless stuff? Also, I'm curious what you think about the class inheritance design I went with. The PyMC code base mostly seems to use class constructors (is that the right name for the classmethod(`__new__`) design pattern? It's still a bit magic to me) for shared functionality.

Also maybe have a look at how I handle missing values/data registration in data_tools.py and PyMCStateSpace.build_statespace_graph. I'm curious whether I can put the nan values back in before it goes to sampling? I'm a bit disappointed that the -9999 values end up in the idata.constant_data group, rather than the NaNs.


ricardoV94 commented Aug 10, 2023

> The PyMC code base mostly seems to use class constructors (is that the right name for the classmethod(`__new__`) design pattern? It's still a bit magic to me) for shared functionality.

That's an artifact from the fact that our classes are just functions in disguise. Not a pattern that we should try to stick with unless necessary. pymc-devs/pymc#5308


ricardoV94 commented Aug 10, 2023

> Also maybe have a look at how I handle missing values/data registration in data_tools.py and PyMCStateSpace.build_statespace_graph. I'm curious if I can put the nan values back in before it goes to sampling? I'm a bit disappointed that the -9999 values ends up in the idata.constant_data group, rather than the NaNs.

As I mentioned above, I think the clean solution is to store a boolean mask of the nan/non-nan entries, and then the unraveled non-nan values. Not sure how helpful, but you can have a look at the more recent PartialObservedRV: https://github.com/pymc-devs/pymc/blob/d59a960f89873667d6190489ff0e975091e57d10/pymc/distributions/distribution.py#L1166-L1292


ricardoV94 commented Aug 10, 2023

> kalman_filter.py is the beating heart, everything else is fancy bibs and bobs. If you can spot any optimizations there, it would offer a lot of improvement across the module. It's the most mathy, but there's also a lot of dumb logic about sorting variables by shape, adding checks, etc

I could only have a superficial look, not being familiar enough with this kind of models and math :(

The Asserts/SpecifyShapes should be pretty cheap / removable during compilation. The transpositions/concatenations are a tax we pay for every time series at the moment. Hopefully they aren't crazy expensive. Transpositions are as cheap as they get; what matters is how data layout plays into subsequent loops, but that's a bit too low-level an optimization to look at before we see unreasonable bottlenecks. Concatenations may or may not require copying arrays, I don't remember now.

I don't see anything obvious from my myopic perspective :D

@jessegrabowski jessegrabowski merged commit 67a9695 into pymc-devs:main Aug 18, 2023
7 checks passed
@jessegrabowski jessegrabowski deleted the statespace branch September 17, 2023 16:32
Labels: enhancements (New feature or request)