Skip to content

Commit 903cc8b

Browse files
seabbs and jamesmbaazam authored
Optimise convolutions (#745)
* refactor convolution code * use tests to check optimisation * optimise get_rev_delats * revert setup changes * revert major delays rewrite * remove use of reverse_mf * Update NEWS.md * Add new line in test * Update inst/stan/functions/convolve.stan Co-authored-by: James Azam <[email protected]> --------- Co-authored-by: James Azam <[email protected]>
1 parent 5ce0043 commit 903cc8b

File tree

6 files changed

+120
-65
lines changed

6 files changed

+120
-65
lines changed

NEWS.md

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
- The interface for defining delay distributions has been generalised to also cater for continuous distributions
1313
- When defining probability distributions these can now be truncated using the `tolerance` argument
1414
- Ornstein-Uhlenbeck and 5 / 2 Matérn kernels have been added. By @sbfnk in # and reviewed by @.
15+
- Optimised convolution code to take into account the relative length of the vectors being convolved. See #745 by @seabbs and reviewed by @jamesmbaazam.
1516
- Switch to broadcasting the day of the week effect. By @seabbs in #746 and reviewed by @jamesmbaazam.
1617
- A warning is now thrown if nonparametric PMFs passed to delay options have consecutive tail values that are below a certain low threshold as these lead to loss in speed with little gain in accuracy. By @jamesmbaazam in #752 and reviewed by @seabbs.
1718

inst/stan/functions/convolve.stan

+89-25
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,100 @@
1-
// convolve two vectors as a backwards dot product
2-
// y vector should be reversed
3-
// limited to the length of x and backwards looking for x indexes
1+
/**
2+
* Calculate convolution indices for the case where s <= xlen
3+
*
4+
* @param s Current position in the output vector
5+
* @param xlen Length of the x vector
6+
* @param ylen Length of the y vector
7+
* @return An array of integers: {start_x, end_x, start_y, end_y}
8+
*/
9+
array[] int calc_conv_indices_xlen(int s, int xlen, int ylen) {
10+
int s_minus_ylen = s - ylen;
11+
int start_x = max(1, s_minus_ylen + 1);
12+
int end_x = s;
13+
int start_y = max(1, 1 - s_minus_ylen);
14+
int end_y = ylen;
15+
return {start_x, end_x, start_y, end_y};
16+
}
17+
18+
/**
19+
* Calculate convolution indices for the case where s > xlen
20+
*
21+
* @param s Current position in the output vector
22+
* @param xlen Length of the x vector
23+
* @param ylen Length of the y vector
24+
* @return An array of integers: {start_x, end_x, start_y, end_y}
25+
*/
26+
array[] int calc_conv_indices_len(int s, int xlen, int ylen) {
27+
int s_minus_ylen = s - ylen;
28+
int start_x = max(1, s_minus_ylen + 1);
29+
int end_x = xlen;
30+
int start_y = max(1, 1 - s_minus_ylen);;
31+
int end_y = ylen + xlen - s;
32+
return {start_x, end_x, start_y, end_y};
33+
}
34+
35+
/**
36+
* Convolve a vector with a reversed probability mass function.
37+
*
38+
* This function performs a discrete convolution of two vectors, where the second vector
39+
* is assumed to be an already reversed probability mass function.
40+
*
41+
* @param x The input vector to be convolved.
42+
* @param y The already reversed probability mass function vector.
43+
* @param len The desired length of the output vector.
44+
* @return A vector of length `len` containing the convolution result.
45+
* @throws If `len` is longer than the full convolution of `x` and `y` (`xlen + ylen - 1`) or shorter than the length of `x`.
46+
*/
447
vector convolve_with_rev_pmf(vector x, vector y, int len) {
5-
int xlen = num_elements(x);
6-
int ylen = num_elements(y);
7-
vector[len] z;
8-
if (xlen + ylen <= len) {
9-
reject("convolve_with_rev_pmf: len is longer then x and y combined");
10-
}
11-
for (s in 1:len) {
12-
z[s] = dot_product(
13-
x[max(1, (s - ylen + 1)):min(s, xlen)],
14-
y[max(1, ylen - s + 1):min(ylen, ylen + xlen - s)]
15-
);
48+
int xlen = num_elements(x);
49+
int ylen = num_elements(y);
50+
vector[len] z;
51+
52+
if (xlen + ylen - 1 < len) {
53+
reject("convolve_with_rev_pmf: len is longer than x and y convolved");
54+
}
55+
56+
if (xlen > len) {
57+
reject("convolve_with_rev_pmf: len is shorter than x");
58+
}
59+
60+
for (s in 1:xlen) {
61+
array[4] int indices = calc_conv_indices_xlen(s, xlen, ylen);
62+
z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]);
63+
}
64+
65+
if (len > xlen) {
66+
for (s in (xlen + 1):len) {
67+
array[4] int indices = calc_conv_indices_len(s, xlen, ylen);
68+
z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]);
1669
}
17-
return(z);
1870
}
71+
72+
return z;
73+
}
1974

20-
21-
// convolve latent infections to reported (but still unobserved) cases
75+
/**
76+
* Convolve infections to reported cases.
77+
*
78+
* This function convolves a vector of infections with a reversed delay
79+
* distribution to produce a vector of reported cases.
80+
*
81+
* @param infections A vector of infection counts.
82+
* @param delay_rev_pmf A vector representing the reversed probability mass
83+
* function of the delay distribution.
84+
* @param seeding_time The number of initial time steps to exclude from the
85+
* output.
86+
* @return A vector of reported cases, starting from `seeding_time + 1`.
87+
*/
2288
vector convolve_to_report(vector infections,
2389
vector delay_rev_pmf,
2490
int seeding_time) {
2591
int t = num_elements(infections);
26-
vector[t - seeding_time] reports;
27-
vector[t] unobs_reports = infections;
2892
int delays = num_elements(delay_rev_pmf);
29-
if (delays) {
30-
unobs_reports = convolve_with_rev_pmf(unobs_reports, delay_rev_pmf, t);
31-
reports = unobs_reports[(seeding_time + 1):t];
32-
} else {
33-
reports = infections[(seeding_time + 1):t];
93+
94+
if (delays == 0) {
95+
return infections[(seeding_time + 1):t];
3496
}
35-
return(reports);
97+
98+
vector[t] unobs_reports = convolve_with_rev_pmf(infections, delay_rev_pmf, t);
99+
return unobs_reports[(seeding_time + 1):t];
36100
}

inst/stan/functions/delays.stan

+3-3
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ vector get_delay_rev_pmf(
4343
pmf[1:new_len] = new_variable_pmf;
4444
} else { // subsequent delay to be convolved
4545
pmf[1:new_len] = convolve_with_rev_pmf(
46-
pmf[1:current_len], reverse_mf(new_variable_pmf), new_len
46+
pmf[1:current_len], reverse(new_variable_pmf), new_len
4747
);
4848
}
4949
} else { // nonparametric
@@ -54,7 +54,7 @@ vector get_delay_rev_pmf(
5454
pmf[1:new_len] = delay_np_pmf[start:end];
5555
} else { // subsequent delay to be convolved
5656
pmf[1:new_len] = convolve_with_rev_pmf(
57-
pmf[1:current_len], reverse_mf(delay_np_pmf[start:end]), new_len
57+
pmf[1:current_len], reverse(delay_np_pmf[start:end]), new_len
5858
);
5959
}
6060
}
@@ -70,7 +70,7 @@ vector get_delay_rev_pmf(
7070
pmf = cumulative_sum(pmf);
7171
}
7272
if (reverse_pmf) {
73-
pmf = reverse_mf(pmf);
73+
pmf = reverse(pmf);
7474
}
7575
return pmf;
7676
}

inst/stan/functions/pmfs.stan

-33
Original file line numberDiff line numberDiff line change
@@ -30,36 +30,3 @@ vector discretised_pmf(vector params, int n, int dist) {
3030
}
3131
return(exp(lpmf));
3232
}
33-
34-
// reverse a mf
35-
vector reverse_mf(vector pmf) {
36-
int pmf_length = num_elements(pmf);
37-
vector[pmf_length] rev_pmf;
38-
for (d in 1:pmf_length) {
39-
rev_pmf[d] = pmf[pmf_length - d + 1];
40-
}
41-
return rev_pmf;
42-
}
43-
44-
vector rev_seq(int base, int len) {
45-
vector[len] seq;
46-
for (i in 1:len) {
47-
seq[i] = len + base - i;
48-
}
49-
return(seq);
50-
}
51-
52-
real rev_pmf_mean(vector rev_pmf, int base) {
53-
int len = num_elements(rev_pmf);
54-
vector[len] rev_pmf_seq = rev_seq(base, len);
55-
return(dot_product(rev_pmf_seq, rev_pmf));
56-
}
57-
58-
real rev_pmf_var(vector rev_pmf, int base, real mean) {
59-
int len = num_elements(rev_pmf);
60-
vector[len] rev_pmf_seq = rev_seq(base, len);
61-
for (i in 1:len) {
62-
rev_pmf_seq[i] = pow(rev_pmf_seq[i], 2);
63-
}
64-
return(dot_product(rev_pmf_seq, rev_pmf) - pow(mean, 2));
65-
}

tests/testthat/test-stan-convole.R

+26-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,21 @@
11
skip_on_cran()
22
skip_on_os("windows")
33

4-
test_that("convolve can combine two pmfs as expected", {
4+
# Test calc_conv_indices_xlen function
5+
test_that("calc_conv_indices_xlen calculates correct indices", {
6+
expect_equal(calc_conv_indices_xlen(1, 5, 3), c(1, 1, 3, 3))
7+
expect_equal(calc_conv_indices_xlen(3, 5, 3), c(1, 3, 1, 3))
8+
expect_equal(calc_conv_indices_xlen(5, 5, 3), c(3, 5, 1, 3))
9+
})
10+
11+
# Test calc_conv_indices_len function
12+
test_that("calc_conv_indices_len calculates correct indices", {
13+
expect_equal(calc_conv_indices_len(6, 5, 3), c(4, 5, 1, 2))
14+
expect_equal(calc_conv_indices_len(7, 5, 3), c(5, 5, 1, 1))
15+
expect_equal(calc_conv_indices_len(8, 5, 3), c(6, 5, 1, 0))
16+
})
17+
18+
test_that("convolve_with_rev_pmf can combine two pmfs as expected", {
519
expect_equal(
620
convolve_with_rev_pmf(c(0.1, 0.2, 0.7), rev(c(0.1, 0.2, 0.7)), 5),
721
c(0.01, 0.04, 0.18, 0.28, 0.49),
@@ -14,7 +28,7 @@ test_that("convolve can combine two pmfs as expected", {
1428
)
1529
})
1630

17-
test_that("convolve performs the same as a numerical convolution", {
31+
test_that("convolve_with_rev_pmf performs the same as a numerical convolution", {
1832
# Sample and analytical PMFs for two Poisson distributions
1933
x <- rpois(10000, 3)
2034
xpmf <- dpois(0:20, 3)
@@ -32,7 +46,7 @@ test_that("convolve performs the same as a numerical convolution", {
3246
expect_lte(sum(abs(conv_cdf - cdf)), 0.1)
3347
})
3448

35-
test_that("convolve_dot_product can combine vectors as we expect", {
49+
test_that("convolve_with_rev_pmf can combine vectors as we expect", {
3650
expect_equal(
3751
convolve_with_rev_pmf(c(0.1, 0.2, 0.7), rev(c(0.1, 0.2, 0.7)), 3),
3852
c(0.01, 0.04, 0.18),
@@ -54,3 +68,12 @@ test_that("convolve_dot_product can combine vectors as we expect", {
5468
x
5569
)
5670
})
71+
72+
test_that("convolve_dot_product can combine two vectors where x > y and len = x", {
73+
x <- c(1, 2, 3, 4, 5)
74+
y <- c(1, 2, 3)
75+
expect_equal(
76+
convolve_with_rev_pmf(x, rev(y), 5),
77+
c(1, 4, 10, 16, 22)
78+
)
79+
})

tests/testthat/test-stan-secondary.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ skip_on_os("windows")
44
# test primary reports and observations
55
reports <- rep(10, 20)
66
obs <- rep(4, 20)
7-
delay_rev_pmf <- reverse_mf(discretised_pmf(c(log(3), 0.1), 5, 0))
7+
delay_rev_pmf <- rev(discretised_pmf(c(log(3), 0.1), 5, 0))
88
scaled <- reports * 0.1
99
convolved <- rep(1e-5, 20) + convolve_to_report(scaled, delay_rev_pmf, 0)
1010

0 commit comments

Comments
 (0)