The following examples walk through using chkptstanr with Stan.
The basic idea is to (1) write a custom Stan model (done by the user), (2) fit the model with cmdstanr (with the desired number of checkpoints), and then (3) return a cmdstanr object. All but step (1) is done internally, so the workflow is very similar to using cmdstanr.
library(chkptstanr)
library(posterior)
library(bayesplot)
The initial overhead is to create a folder that will store the checkpoints, i.e.,
path <- create_folder(folder_name = "chkpt_folder_m1")
Next is the Stan model:
stan_code <- "
data {
int<lower=0> n;
real y[n];
real<lower=0> sigma[n];
}
parameters {
real mu;
real<lower=0> tau;
vector[n] eta;
}
transformed parameters {
vector[n] theta;
theta = mu + tau * eta;
}
model {
target += normal_lpdf(eta | 0, 1);
target += normal_lpdf(y | theta, sigma);
}
"
chkpt_stan() requires a list supplied to the data argument, much like rstan.
stan_data <- schools.data <- list(
  n = 8,
  y = c(28, 8, -3, 7, -1, 1, 18, 12),
  sigma = c(15, 10, 16, 11, 9, 11, 10, 18)
)
To show the basic idea of checkpointing, the following was stopped after 2 checkpoints.
fit_m1 <- chkpt_stan(model_code = stan_code,
                     data = stan_data,
                     iter_warmup = 1000,
                     iter_sampling = 1000,
                     iter_per_chkpt = 250,
                     path = path)
#> Compiling Stan program...
#> Initial Warmup (Typical Set)
#> Chkpt: 1 / 8; Iteration: 250 / 2000 (warmup)
#> Chkpt: 2 / 8; Iteration: 500 / 2000 (warmup)
To finish the remaining 6 checkpoints, run the same code again, i.e.,
fit_m1 <- chkpt_stan(model_code = stan_code,
                     data = stan_data,
                     iter_warmup = 1000,
                     iter_sampling = 1000,
                     iter_per_chkpt = 250,
                     path = path)
#> Sampling next checkpoint
#> Chkpt: 3 / 8; Iteration: 750 / 2000 (warmup)
#> Chkpt: 4 / 8; Iteration: 1000 / 2000 (warmup)
#> Chkpt: 5 / 8; Iteration: 1250 / 2000 (sample)
#> Chkpt: 6 / 8; Iteration: 1500 / 2000 (sample)
#> Chkpt: 7 / 8; Iteration: 1750 / 2000 (sample)
#> Chkpt: 8 / 8; Iteration: 2000 / 2000 (sample)
#> Checkpointing complete
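The checkpoint counts in the log above follow from simple arithmetic on the arguments. As a minimal sketch in plain R, independent of chkptstanr (the names total_iter, n_chkpt, chkpt_end, and phase are illustrative and not part of the package), the schedule can be reproduced like this:

```r
# Arguments as passed to chkpt_stan() above
iter_warmup    <- 1000
iter_sampling  <- 1000
iter_per_chkpt <- 250

# Total number of checkpoints: all iterations split into equal blocks
total_iter <- iter_warmup + iter_sampling
n_chkpt    <- total_iter / iter_per_chkpt

# Last iteration covered by each checkpoint, and the phase it falls in
chkpt_end <- seq(iter_per_chkpt, total_iter, by = iter_per_chkpt)
phase     <- ifelse(chkpt_end <= iter_warmup, "warmup", "sample")

data.frame(chkpt = seq_len(n_chkpt), iteration = chkpt_end, phase = phase)
```

With these settings the sketch yields 8 checkpoints, four in warmup and four in sampling, which matches the `Chkpt: k / 8` lines printed by the two runs.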
Each checkpoint contains 250 draws from the posterior. These need to be combined with combine_chkpt_draws(), i.e.,
draws <- combine_chkpt_draws(fit_m1)
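Conceptually, combining checkpoints amounts to stacking the per-checkpoint draws along the iteration dimension. The sketch below illustrates the idea in base R with simulated numbers for a single parameter. It is not the implementation of combine_chkpt_draws(); the names chkpt_list and combined are hypothetical, and the assumption of 2 chains and 4 sampling checkpoints is taken from the output shown in this example.

```r
set.seed(1)
n_chains         <- 2    # assumed: matches the 2 chains in this example
draws_per_chkpt  <- 250  # iter_per_chkpt
n_sampling_chkpt <- 4    # iter_sampling / iter_per_chkpt

# One iterations-by-chains matrix per sampling checkpoint (simulated values)
chkpt_list <- lapply(seq_len(n_sampling_chkpt), function(i) {
  matrix(rnorm(draws_per_chkpt * n_chains),
         nrow = draws_per_chkpt, ncol = n_chains)
})

# Stack the checkpoints in order so that iterations run from 1 to 1000
combined <- do.call(rbind, chkpt_list)
dim(combined)
```

The stacked matrix has 1000 rows (iterations) and 2 columns (chains) for this one parameter, the same iteration and chain counts that the combined draws report.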
We developed chkptstanr to work seamlessly with the Stan ecosystem. The object draws has been constructed to mimic what is provided when using cmdstanr directly.
combine_chkpt_draws(fit_m1)
#> # A draws_array: 1000 iterations, 2 chains, and 19 variables
#> , , variable = lp__
#>
#> chain
#> iteration 1 2
#> 1 -34 -43
#> 2 -37 -41
#> 3 -36 -39
#> 4 -38 -38
#> 5 -38 -41
#>
#> , , variable = mu
#>
#> chain
#> iteration 1 2
#> 1 5.2 2.6
#> 2 11.3 6.7
#> 3 -2.7 5.3
#> 4 -2.9 3.7
#> 5 -2.7 14.2
#>
#> , , variable = tau
#>
#> chain
#> iteration 1 2
#> 1 23.3 2.61
#> 2 6.7 0.21
#> 3 12.7 4.44
#> 4 21.1 7.29
#> 5 18.8 10.94
#>
#> , , variable = eta[1]
#>
#> chain
#> iteration 1 2
#> 1 0.10 -0.61
#> 2 0.89 -0.87
#> 3 1.62 0.83
#> 4 1.99 0.84
#> 5 -0.16 1.22
#>
#> # ... with 995 more iterations, and 15 more variables
The draws object can then be used with the R package posterior:
posterior::summarise_draws(draws)
#> # A tibble: 19 x 10
#> variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lp__ -39.5 -39.2 2.59 2.58 -44.2 -35.9 1.00 640. 1008.
#> 2 mu 7.77 7.92 5.48 5.10 -1.43 16.0 1.01 530. 325.
#> 3 tau 6.82 5.32 5.75 4.71 0.434 18.7 1.00 649. 658.
#> 4 eta[1] 0.383 0.413 0.929 0.909 -1.20 1.87 1.00 1650. 1233.
#> 5 eta[2] -0.00335 -0.00816 0.841 0.814 -1.34 1.40 1.00 1443. 1307.
#> 6 eta[3] -0.176 -0.174 0.931 0.906 -1.67 1.42 1.00 1829. 1424.
#> 7 eta[4] -0.00521 0.000856 0.862 0.841 -1.47 1.39 1.00 1565. 1407.
#> 8 eta[5] -0.312 -0.350 0.873 0.835 -1.72 1.24 1.00 1661. 1616.
#> 9 eta[6] -0.193 -0.190 0.889 0.909 -1.59 1.28 1.00 1915. 1404.
#> 10 eta[7] 0.387 0.358 0.876 0.864 -1.09 1.81 1.00 1574. 1370.
#> 11 eta[8] 0.0805 0.0611 0.970 0.960 -1.51 1.66 1.00 1031. 1236.
#> 12 theta[1] 11.5 10.2 8.29 6.99 0.268 26.4 1.00 1042. 728.
#> 13 theta[2] 7.87 7.87 6.20 5.66 -2.27 17.8 1.00 1549. 1515.
#> 14 theta[3] 6.01 6.63 8.25 6.63 -8.69 18.1 1.00 1102. 1075.
#> 15 theta[4] 7.75 7.76 6.65 5.96 -3.06 18.9 1.00 1674. 1210.
#> 16 theta[5] 5.05 5.70 6.44 5.75 -7.06 14.4 1.00 1405. 1416.
#> 17 theta[6] 6.21 6.60 6.92 6.15 -5.98 16.9 1.00 1890. 1195.
#> 18 theta[7] 10.8 10.1 6.71 6.03 0.992 23.1 1.00 1497. 1767.
#> 19 theta[8] 8.35 8.41 7.72 6.66 -3.88 20.7 1.00 1081. 1075.
The popular R package bayesplot can also be used.
# geom_vline() comes from ggplot2, which is not attached above, so it is namespaced here
bayesplot::mcmc_trace(draws) +
  ggplot2::geom_vline(xintercept = seq(0, 1000, 250),
                      alpha = 0.25,
                      size = 2)
The vertical lines are placed at each checkpoint.