Structural Time Series (STS) in JAX
This library has a similar design to tfp.sts, but is built entirely in JAX and uses the Dynamax library for state-space models. We also include an implementation of the causal impact method, which has a similar design to tfcausalimpact and is also built entirely in JAX.
To install the latest development branch:
pip install git+
or use
git clone git@github.com:probml/sts-jax.git
cd sts-jax
pip install -e .
The STS model is a linear state space model with a specific structure. In particular, the latent state $z_t$ is the concatenation of the states of all latent components.
The STS model (with scalar Gaussian observations) takes the form:
$$y_t = H_t z_t + u_t + \epsilon_t, \qquad \epsilon_t \sim \mathcal{N}(0, \sigma^2_t),$$
$$z_{t+1} = F_t z_t + R_t \eta_t, \qquad \eta_t \sim \mathcal{N}(0, Q_t),$$
where
- $y_t$: observation (emission) at time $t$.
- $\sigma^2_t$: variance of the observation noise $\epsilon_t$.
- $H_t$: emission matrix, which sums up the contributions of all latent components.
- $u_t = x_t^T \beta$: regression component from the external inputs $x_t$.
- $F_t$: fixed transition matrix of the latent dynamics.
- $R_t$: selection matrix, whose columns are a subset of the standard basis vectors $e_i$; it maps the non-singular noise covariance $Q_t$ into the (possibly singular) covariance of the latent state $z_t$.
- $Q_t$: non-singular covariance matrix of the latent noise $\eta_t$, so its dimension can be smaller than the dimension of $z_t$.

The covariance matrix of the latent dynamics model therefore takes the form $\mathrm{Cov}(R_t \eta_t) = R_t Q_t R_t^T$.
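To see why $R_t Q_t R_t^T$ can be singular even though $Q_t$ is not, here is a small numeric sketch. The component (a 3-dimensional state driven by one scalar noise term) and the variance value are hypothetical choices for illustration, not taken from the library.

```python
import jax.numpy as jnp

# Hypothetical component: a 3-dimensional latent state driven by a single
# scalar noise term (for example, a dummy seasonal component with 4 seasons).
state_dim, noise_dim = 3, 1

# Selection matrix R_t: its only column is the basis vector e_1, so the
# noise enters the first coordinate of the state only.
R = jnp.eye(state_dim)[:, :noise_dim]   # shape (3, 1)
Q = jnp.array([[0.5]])                  # non-singular 1x1 noise covariance

# Covariance of the state innovation R_t @ eta_t with eta_t ~ N(0, Q_t).
state_cov = R @ Q @ R.T                 # shape (3, 3)
rank = jnp.linalg.matrix_rank(state_cov)
print(state_cov)
print(rank)                             # rank 1, smaller than state_dim = 3
```

Storing only the small non-singular $Q_t$ and the selection matrix avoids working with a degenerate full-size covariance directly.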
More information on STS models can be found in these books:
- "Machine Learning: Advanced Topics", K. Murphy, MIT Press 2023. Available at
- "Time Series Analysis by State Space Methods (2nd edn)", James Durbin, Siem Jan Koopman, Oxford University Press, 2012.
In this library, an STS model is constructed by providing the observed time series, a list of components, and the distribution family of the observations. The library implements common STS components, including local linear trend, seasonal, cycle, autoregressive, and regression components. The observations can follow either a Gaussian or a Poisson distribution (other likelihood functions can also be added).
Internally, the STS model is converted to the corresponding state space model (SSM) and inference
and learning of parameters are performed on the SSM.
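Because the latent state is the concatenation of the component states, the conversion to an SSM can be pictured as stacking component matrices: $F$, $R$ and $Q$ become block diagonal and $H$ is the horizontal concatenation of the component emission rows. The sketch below illustrates that idea for a local linear trend plus an AR(1) component; the helper names (`local_linear_trend`, `autoregressive_1`, `combine`) are hypothetical and this is not the library's internal code.

```python
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

def local_linear_trend(level_var, slope_var):
    """Hypothetical helper: state (level, slope); both coordinates get noise."""
    F = jnp.array([[1.0, 1.0],
                   [0.0, 1.0]])
    H = jnp.array([[1.0, 0.0]])          # only the level is observed
    R = jnp.eye(2)
    Q = jnp.diag(jnp.array([level_var, slope_var]))
    return F, H, R, Q

def autoregressive_1(coef, noise_var):
    """Hypothetical helper: AR(1) state c_{t+1} = coef * c_t + noise."""
    F = jnp.array([[coef]])
    H = jnp.array([[1.0]])
    R = jnp.eye(1)
    Q = jnp.array([[noise_var]])
    return F, H, R, Q

def combine(components):
    """Stack component matrices into one SSM.

    F, R, Q are block diagonal and H is concatenated along columns, so
    H @ z is the sum of the individual component contributions.
    """
    Fs, Hs, Rs, Qs = zip(*components)
    return (block_diag(*Fs), jnp.concatenate(Hs, axis=1),
            block_diag(*Rs), block_diag(*Qs))

F, H, R, Q = combine([local_linear_trend(0.1, 0.01), autoregressive_1(0.8, 0.5)])
print(F.shape, H.shape, R.shape, Q.shape)
```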
If the observation
The marginal likelihood of
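For Gaussian observations, the marginal likelihood of a linear-Gaussian SSM can be computed exactly with the Kalman filter through the prediction-error decomposition $\log p(y_{1:T}) = \sum_t \log \mathcal{N}(y_t \mid \hat{y}_{t \mid t-1}, S_t)$. The sketch below is a textbook filter written directly in JAX under that assumption; `kalman_log_likelihood` is a hypothetical name, and the library itself relies on Dynamax rather than this code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm

def kalman_log_likelihood(ys, F, H, R, Q, obs_var, m0, P0):
    """log p(y_{1:T}) for y_t = H z_t + eps_t, z_{t+1} = F z_t + R eta_t,
    with eps_t ~ N(0, obs_var), eta_t ~ N(0, Q) and z_1 ~ N(m0, P0).
    ys holds T scalar observations."""
    state_noise_cov = R @ Q @ R.T

    def step(carry, y):
        m, P, ll = carry                      # predictive moments of z_t
        y_mean = (H @ m)[0]
        S = (H @ P @ H.T)[0, 0] + obs_var     # predictive variance of y_t
        ll = ll + norm.logpdf(y, y_mean, jnp.sqrt(S))
        K = (P @ H.T)[:, 0] / S               # Kalman gain
        m_filt = m + K * (y - y_mean)         # condition on y_t
        P_filt = P - S * jnp.outer(K, K)
        m_next = F @ m_filt                   # predict z_{t+1}
        P_next = F @ P_filt @ F.T + state_noise_cov
        return (m_next, P_next, ll), None

    (_, _, ll), _ = jax.lax.scan(step, (m0, P0, jnp.zeros(())), ys)
    return ll

# Local level model (F = H = R = 1) on three observations.
ys = jnp.array([1.0, 1.5, 0.7])
ll = kalman_log_likelihood(ys, F=jnp.eye(1), H=jnp.eye(1), R=jnp.eye(1),
                           Q=jnp.array([[0.1]]), obs_var=0.5,
                           m0=jnp.zeros(1), P0=jnp.eye(1))

# Check against the dense joint Gaussian: Cov(z_i, z_j) = 1 + 0.1 * (min(i, j) - 1).
cov_z = jnp.array([[1.0, 1.0, 1.0],
                   [1.0, 1.1, 1.1],
                   [1.0, 1.1, 1.2]])
ll_dense = multivariate_normal.logpdf(ys, jnp.zeros(3), cov_z + 0.5 * jnp.eye(3))
print(ll, ll_dense)
```

The dense check is only practical for very short series; the filter's cost grows linearly in $T$, which is why it is the standard way to evaluate this likelihood.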
Below we illustrate the API applied to some example datasets.
This example is adapted from the TFP blog. See this file for a runnable version of this demo.
The problem of interest is to forecast electricity demand in Victoria, Australia. The dataset contains hourly records of electricity demand and temperature from the first 8 weeks of 2014. The following plot shows the training set, which contains the measurements from the first 6 weeks.
We now build a model in which demand depends linearly on temperature and which also has two seasonal components and an autoregressive component.
import sts_jax.structural_time_series.sts_model as sts
hour_of_day_effect = sts.SeasonalDummy(num_seasons=24,
day_of_week_effect = sts.SeasonalTrig(num_seasons=7, num_steps_per_season=24,
temperature_effect = sts.LinearRegression(dim_covariates=1, add_bias=True,
autoregress_effect = sts.Autoregressive(order=1,
# The STS model is constructed by providing the observed time series,
# specifying a list of components and the distribution family of the observations.
model = sts.StructuralTimeSeries(
[hour_of_day_effect, day_of_week_effect, temperature_effect, autoregress_effect],
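As an illustration of what a 24-season component encodes, the sketch below builds the transition matrix of the textbook dummy-seasonal model (Durbin and Koopman) and checks that, without noise, the hourly effects repeat every 24 steps. Whether `sts.SeasonalDummy` uses exactly this parameterization is an assumption, and `dummy_seasonal_transition` is a hypothetical helper.

```python
import jax.numpy as jnp

def dummy_seasonal_transition(num_seasons):
    """Transition matrix of the textbook dummy-seasonal component.

    The state holds the most recent (num_seasons - 1) seasonal effects. The
    next effect is minus their sum, so any num_seasons consecutive effects
    sum to zero when there is no noise.
    """
    dim = num_seasons - 1
    new_effect_row = -jnp.ones((1, dim))
    shift_rows = jnp.eye(dim - 1, dim)   # pushes each stored effect down one slot
    return jnp.concatenate([new_effect_row, shift_rows], axis=0)

F = dummy_seasonal_transition(24)        # hour-of-day component: shape (23, 23)

# Roll the noise-free state forward for two days (48 hours).
state = jnp.arange(1.0, 24.0)            # arbitrary initial effects
effects = []
for _ in range(48):
    effects.append(state[0])             # the emission reads the current effect
    state = F @ state
effects = jnp.stack(effects)

print(jnp.allclose(effects[:24], effects[24:]))                     # period 24
print(jnp.allclose(jnp.linalg.matrix_power(F, 24), jnp.eye(23)))    # F^24 = I
```

The sum-to-zero constraint keeps the seasonal effects separate from the overall level, which would otherwise be unidentifiable.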
In this case, we fit the model by maximum likelihood estimation (MLE).
# Estimate the parameters by MLE, using the SGD routine implemented in the Dynamax library.
opt_param, _losses = model.fit_mle(obs_time_series,
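To make "MLE via gradient-based optimization" concrete, here is a minimal sketch on a scalar local level model: the noise variances are optimized in log space by plain gradient descent on the average negative log marginal likelihood from a Kalman filter. This swaps in plain gradient descent for the SGD routine mentioned above, uses simulated data, and is not how `fit_mle` is implemented; `local_level_nll` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def local_level_nll(log_params, ys):
    """Average negative log marginal likelihood of a local level model
    z_{t+1} = z_t + eta_t, y_t = z_t + eps_t, via a scalar Kalman filter.
    log_params = (log Var(eta_t), log Var(eps_t))."""
    q, r = jnp.exp(log_params)

    def step(carry, y):
        m, P, ll = carry                  # predictive mean/variance of z_t
        S = P + r                         # predictive variance of y_t
        ll = ll + norm.logpdf(y, m, jnp.sqrt(S))
        K = P / S
        m = m + K * (y - m)               # filter update
        P = (1.0 - K) * P + q             # followed by the prediction step
        return (m, P, ll), None

    init = (jnp.zeros(()), 10.0 * jnp.ones(()), jnp.zeros(()))
    (_, _, ll), _ = jax.lax.scan(step, init, ys)
    return -ll / ys.shape[0]

# Simulated data with Var(eta) = 0.1 and Var(eps) = 0.5.
key_z, key_y = jax.random.split(jax.random.PRNGKey(0))
T = 200
z = jnp.cumsum(jnp.sqrt(0.1) * jax.random.normal(key_z, (T,)))
ys = z + jnp.sqrt(0.5) * jax.random.normal(key_y, (T,))

loss_and_grad = jax.jit(jax.value_and_grad(local_level_nll))
log_params = jnp.zeros(2)                 # start from q = r = 1
losses = []
for _ in range(300):
    loss, grad = loss_and_grad(log_params, ys)
    log_params = log_params - 0.1 * grad  # plain gradient descent step
    losses.append(float(loss))

print(jnp.exp(log_params), losses[0], losses[-1])
```

Optimizing log-variances keeps the variances positive without constrained optimization.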
We can now plug in the parameters and the future inputs, and use ancestral sampling from the filtered posterior to forecast future observations.
# The 'forecast' method samples the future means and future observations from the
# predictive distribution, conditioned on the parameters of the model.
forecast_means, forecasts = model.forecast(opt_param,
The following plot shows the mean and 95% probability interval of the forecast.
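The forecasting step can be sketched as ancestral sampling: draw the last latent state from its filtered posterior, then alternate state transitions and noisy emissions, and read off the mean and a 95% interval from the sample percentiles. The sketch below assumes a linear-Gaussian SSM with known matrices and a given filtered mean and covariance $(m_T, P_T)$; `sample_forecasts` and all numeric values are hypothetical, not the library's `forecast` method.

```python
import jax
import jax.numpy as jnp

def sample_forecasts(key, m_T, P_T, F, H, R, Q, obs_var, num_steps, num_samples):
    """Ancestral sampling of future observations from a linear-Gaussian SSM.

    Starts from z_T ~ N(m_T, P_T), then alternates z_{t+1} = F z_t + R eta_t
    and y_{t+1} = H z_{t+1} + eps_{t+1}. Returns (means, samples): the
    noise-free means H z_{t+1} and the noisy draws, each (num_samples, num_steps).
    """
    def one_path(key):
        key0, key_path = jax.random.split(key)
        z = jax.random.multivariate_normal(key0, m_T, P_T)

        def step(z, key):
            k_state, k_obs = jax.random.split(key)
            eta = jax.random.multivariate_normal(k_state, jnp.zeros(Q.shape[0]), Q)
            z = F @ z + R @ eta
            mean = (H @ z)[0]
            y = mean + jnp.sqrt(obs_var) * jax.random.normal(k_obs)
            return z, (mean, y)

        _, (means, ys) = jax.lax.scan(step, z, jax.random.split(key_path, num_steps))
        return means, ys

    return jax.vmap(one_path)(jax.random.split(key, num_samples))

# Example: a local linear trend with a level-only observation.
F = jnp.array([[1.0, 1.0], [0.0, 1.0]])
H = jnp.array([[1.0, 0.0]])
R = jnp.eye(2)
Q = jnp.diag(jnp.array([0.1, 0.01]))
means, samples = sample_forecasts(jax.random.PRNGKey(1), m_T=jnp.array([5.0, 0.2]),
                                  P_T=0.1 * jnp.eye(2), F=F, H=H, R=R, Q=Q,
                                  obs_var=0.5, num_steps=24, num_samples=500)
forecast_mean = means.mean(axis=0)                                # shape (24,)
lower, upper = jnp.percentile(samples, jnp.array([2.5, 97.5]), axis=0)
```

With several parameter samples (for example from HMC), the same routine can be run once per sample and the draws pooled.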
This example is adapted from the TFP blog. See this file for a runnable version of the demo, which is similar to the electricity example.
We can also fit STS models with discrete observations following the Poisson
distribution. Internally, the inference of the latent states
Below we create a synthetic dataset, following this TFP example. See this file for a runnable version of this demo.
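For readers who want a dataset of this general kind without the TFP example at hand, the sketch below generates counts whose log-rate follows a local linear trend, i.e. $y_t \sim \text{Poisson}(e^{\text{level}_t})$. The log link, the noise scales and the helper name `simulate_poisson_trend` are assumptions for illustration; this is not the dataset used in the demo.

```python
import jax
import jax.numpy as jnp

def simulate_poisson_trend(key, num_steps, level_scale=0.05, slope_scale=0.01,
                           init_level=1.0, init_slope=0.02):
    """Hypothetical generator: a local linear trend drives the log-rate of
    Poisson counts. The slope is a random walk, and the level increments by
    the current slope plus noise."""
    k_level, k_slope, k_obs = jax.random.split(key, 3)
    slope = init_slope + jnp.cumsum(slope_scale * jax.random.normal(k_slope, (num_steps,)))
    level = init_level + jnp.cumsum(slope + level_scale * jax.random.normal(k_level, (num_steps,)))
    counts = jax.random.poisson(k_obs, jnp.exp(level))   # integer counts
    return level, counts

level, counts = simulate_poisson_trend(jax.random.PRNGKey(0), num_steps=100)
```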
import sts_jax.structural_time_series.sts_model as sts
# This example uses a synthetic dataset and the STS model contains only a
# local linear trend component.
trend = sts.LocalLinearTrend()
model = sts.StructuralTimeSeries([trend],
# Fit the model using the HMC algorithm.
param_samples, _log_probs = model.fit_hmc(num_samples=200,
# Forecast into the future given samples of parameters returned by the HMC algorithm.
forecasts = model.forecast(param_samples, obs_time_series, num_forecast_steps)[1]
The TFP approach to STS with non-conjugate likelihoods is to perform
HMC on the joint distribution of the latent states
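Whatever sampler is used, running HMC over the latent states requires a differentiable log joint density. The minimal sketch below writes one for a local linear trend with Poisson observations, assuming a log link, fixed noise scales and a weak Gaussian prior on the initial state; a full treatment would also sample the parameters. `log_joint` is a hypothetical name and no HMC sampler is included.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm, poisson

def log_joint(latents, counts, level_scale, slope_scale):
    """Log joint density log p(z_{1:T}, y_{1:T}) of a local linear trend with
    Poisson observations y_t ~ Poisson(exp(level_t)).
    latents has shape (T, 2) with columns (level, slope)."""
    level, slope = latents[:, 0], latents[:, 1]
    # Transition densities for t = 2..T.
    lp = norm.logpdf(level[1:], level[:-1] + slope[:-1], level_scale).sum()
    lp += norm.logpdf(slope[1:], slope[:-1], slope_scale).sum()
    # Weak prior on the initial state (hypothetical choice).
    lp += norm.logpdf(latents[0], 0.0, 10.0).sum()
    # Poisson observation log-likelihood with a log link.
    lp += poisson.logpmf(counts, jnp.exp(level)).sum()
    return lp

# HMC needs the gradient with respect to the latent states.
grad_log_joint = jax.grad(log_joint)

T = 50
counts = jnp.arange(T) % 7               # placeholder counts, for illustration only
latents = jnp.zeros((T, 2))
value = log_joint(latents, counts, 0.1, 0.01)
grads = grad_log_joint(latents, counts, 0.1, 0.01)   # shape (T, 2)
```

The latent dimension grows with $T$, which is one reason this joint sampling can be expensive for long series.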
The causal impact method is implemented on top of the STS-JAX package.
Below we show an example, where Y is the output time series and X is a parallel
set of input covariates. We notice a sudden change in the response variable at time
This is how we run inference:
from sts_jax.causal_impact.causal_impact import causal_impact
# The causal impact is inferred by providing the target time series and covariates,
# specifying the intervention time and the distribution family of the observation.
# If the STS model is not given, an STS model with only a local linear trend component
# in addition to the regression component is constructed by default internally.
impact = causal_impact(obs_time_series,
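To experiment with a call like the one above, a hypothetical dataset of the kind described can be generated as follows: $Y$ tracks a random-walk covariate $X$ linearly and then jumps by a fixed amount from an intervention time onwards. The helper name, the coefficient 1.2, the jump size and the intervention time are all illustrative assumptions, not the data behind the output shown below.

```python
import jax
import jax.numpy as jnp

def simulate_causal_data(key, num_steps=100, intervention_time=70, effect=10.0):
    """Hypothetical dataset: Y follows a covariate X linearly, then shifts by
    `effect` from `intervention_time` onwards."""
    k_x, k_y = jax.random.split(key)
    x = 100.0 + jnp.cumsum(jax.random.normal(k_x, (num_steps,)))   # random-walk covariate
    y = 1.2 * x + jax.random.normal(k_y, (num_steps,))
    y = y.at[intervention_time:].add(effect)                        # the intervention
    return x[:, None], y                                            # X as (T, 1)

x, y = simulate_causal_data(jax.random.PRNGKey(0))
```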
The format of the output from our causal impact code follows that of the R package CausalImpact, and is shown below.
Posterior inference of the causal impact:

                          Average             Cumulative
Actual                    129.93              3897.88
Prediction (s.d.)         120.01 (2.04)       3600.42 (61.31)
95% CI                    [114.82, 123.07]    [3444.72, 3692.09]

Absolute effect (s.d.)    9.92 (2.04)         297.45 (61.31)
95% CI                    [6.86, 15.11]       [205.78, 453.16]

Relative effect (s.d.)    8.29% (1.89%)       8.29% (1.89%)
95% CI                    [5.57%, 13.16%]     [5.57%, 13.16%]

Posterior tail-area probability p: 0.0050
Posterior prob of a causal effect: 99.50%
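The rows of this summary can be read as statistics of posterior draws of the counterfactual (the series predicted as if no intervention had happened) over the post-intervention period. The sketch below computes them under that reading; it also shows why the two relative-effect columns agree, since $(\sum a - \sum p)/\sum p = (\bar a - \bar p)/\bar p$. The tail-area probability here is a simple one-sided estimate and may differ from the exact estimator used by CausalImpact or this library; `summarize_impact` and the placeholder inputs are hypothetical.

```python
import jax
import jax.numpy as jnp

def summarize_impact(actual, predicted_samples):
    """CausalImpact-style summary for the post-intervention period.

    actual:            shape (T_post,), observed values after the intervention.
    predicted_samples: shape (S, T_post), posterior draws of the counterfactual.
    """
    interval = lambda v: jnp.percentile(v, jnp.array([2.5, 97.5]))
    pred_avg = predicted_samples.mean(axis=1)            # per-draw average
    pred_cum = predicted_samples.sum(axis=1)             # per-draw cumulative
    abs_avg = actual.mean() - pred_avg                   # per-draw absolute effect
    abs_cum = actual.sum() - pred_cum
    rel_avg = abs_avg / pred_avg
    rel_cum = abs_cum / pred_cum                         # equals rel_avg up to rounding
    # Simple one-sided tail-area probability (an assumption, see lead-in).
    p = jnp.minimum(jnp.mean(abs_cum <= 0), jnp.mean(abs_cum >= 0))
    return {
        "actual": (actual.mean(), actual.sum()),
        "prediction": (pred_avg.mean(), pred_cum.mean()),
        "prediction_sd": (pred_avg.std(), pred_cum.std()),
        "prediction_ci": (interval(pred_avg), interval(pred_cum)),
        "absolute_effect": (abs_avg.mean(), abs_cum.mean()),
        "absolute_effect_ci": (interval(abs_avg), interval(abs_cum)),
        "relative_effect": (rel_avg.mean(), rel_cum.mean()),
        "relative_effect_ci": (interval(rel_avg), interval(rel_cum)),
        "p": p,
        "prob_causal_effect": 1.0 - p,
    }

# Placeholder inputs: 30 post-intervention points and 1000 counterfactual draws.
actual = 130.0 + jax.random.normal(jax.random.PRNGKey(0), (30,))
predicted = (120.0 + 2.0 * jax.random.normal(jax.random.PRNGKey(1), (1000, 1))
             + jax.random.normal(jax.random.PRNGKey(2), (1000, 30)))
summary = summarize_impact(actual, predicted)
```

Because the absolute effect is the actual value minus the prediction, its interval is the prediction interval reflected around the actual value.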
Authors: Xinlong Xi, Kevin Murphy.
MIT License. 2022