Bayesian analysis with brms and marginaleffects

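The marginaleffects package offers convenience functions to compute and display predictions, contrasts, and marginal effects from Bayesian models estimated by the brms package. To compute these quantities, marginaleffects relies on workhorse functions from the brms package to draw from the posterior distribution. The type of draws is controlled by the type argument of the predictions or marginaleffects functions: "response" draws from the expected value of the outcome (brms::posterior_epred), "link" draws from the linear predictor (brms::posterior_linpred), and "prediction" draws from the posterior predictive distribution (brms::posterior_predict).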

The predictions and marginaleffects functions can also pass additional arguments to the brms prediction functions via the ... ellipsis. For example, if mod is a mixed-effects model, then this command will compute 10 draws from the posterior predictive distribution, while ignoring all group-level effects:

predictions(mod, type = "prediction", ndraws = 10, re_formula = NA)

See the brms documentation for a list of available arguments:

?brms::posterior_epred
?brms::posterior_linpred
?brms::posterior_predict
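
To make the difference between the three kinds of draws concrete, here is a minimal base R sketch that does not call brms. The coefficient draws (b0, b_age) are simulated stand-ins for a posterior, not output from a fitted model, and all object names are illustrative assumptions.

```r
# Hypothetical posterior draws for a logit model with one predictor.
# These numbers are made up for illustration; they do not come from brms.
set.seed(123)
n_draws <- 1000
b0    <- rnorm(n_draws, mean = -1,    sd = 0.2)
b_age <- rnorm(n_draws, mean = -0.02, sd = 0.005)
age   <- 30

# Linear predictor draws (log-odds scale)
link <- b0 + b_age * age

# Expected value draws (probability scale)
response <- plogis(link)

# Posterior predictive draws (simulated 0/1 outcomes)
prediction <- rbinom(n_draws, size = 1, prob = response)

summary(link)
summary(response)
table(prediction)
```

The first two vectors differ only by the inverse link function, while the third adds observation-level sampling noise on top of the uncertainty in the coefficients.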

Note that support for brms will be available in version 0.3.0 of marginaleffects. Until that version is released on CRAN, it can be installed from the development repository:

library(remotes)
install_github("vincentarelbundock/marginaleffects")

Logistic regression with multiplicative interactions

Load libraries and download data on passengers of the Titanic from the Rdatasets archive:

library(brms)
library(marginaleffects)
library(ggplot2)
library(ggdist)

dat <- read.csv("https://vincentarelbundock.github.io/Rdatasets/csv/carData/TitanicSurvival.csv")
dat$survived <- ifelse(dat$survived == "yes", 1, 0)
dat$woman <- ifelse(dat$sex == "female", 1, 0)

Fit a logit model with a multiplicative interaction:

mod <- brm(survived ~ woman * age + passengerClass,
           family = bernoulli(link = "logit"),
           backend = "cmdstanr", cores = 4,
           data = dat)

Adjusted predictions

We can compute adjusted predicted values of the outcome variable (i.e., probability of survival aboard the Titanic) using the predictions function. By default, this function calculates predictions for each row of the dataset:

pred <- predictions(mod)
head(pred)
#>   rowid     type predicted survived woman     age passengerClass  conf.low
#> 1     4 response 0.9361746        1     1 29.0000            1st 0.9088526
#> 2     5 response 0.8477418        1     0  0.9167            1st 0.7529371
#> 3     6 response 0.9426519        0     1  2.0000            1st 0.9010494
#> 4     7 response 0.5125855        0     0 30.0000            1st 0.4309135
#> 5     8 response 0.9371363        0     1 25.0000            1st 0.9091038
#> 6     9 response 0.2725831        1     0 48.0000            1st 0.2014191
#>   conf.high
#> 1 0.9595733
#> 2 0.9225402
#> 3 0.9733116
#> 4 0.5987987
#> 5 0.9606767
#> 6 0.3455888

To visualize the relationship between the outcome and one of the regressors, we can plot conditional adjusted predictions with the plot_cap function:

plot_cap(mod, condition = "age")

Compute adjusted predictions for some user-specified values of the regressors, using the newdata argument and the typical function:

pred <- predictions(mod, newdata = typical(woman = 0:1, passengerClass = c("1st", "2nd", "3rd")))
pred
#>   rowid     type  predicted      age woman passengerClass   conf.low conf.high
#> 1     4 response 0.51419465 29.88113     0            1st 0.43366926 0.6018784
#> 2     5 response 0.93603218 29.88113     1            1st 0.90847257 0.9594656
#> 3     6 response 0.20265337 29.88113     0            2nd 0.15114590 0.2613733
#> 4     7 response 0.77774811 29.88113     1            2nd 0.71671854 0.8378469
#> 5     8 response 0.08775081 29.88113     0            3rd 0.06373833 0.1130598
#> 6     9 response 0.57136784 29.88113     1            3rd 0.50042547 0.6465660

The get_posterior_draws function extracts the posterior draws behind these estimates and returns a data frame with drawid and draw columns:

pred <- get_posterior_draws(pred)
head(pred)
#>        type rowid_internal drawid      draw rowid predicted      age woman
#> 1: response              1      1 0.5014001     4 0.5141946 29.88113     0
#> 2: response              1      2 0.4514729     4 0.5141946 29.88113     0
#> 3: response              1      3 0.4898628     4 0.5141946 29.88113     0
#> 4: response              1      4 0.5349225     4 0.5141946 29.88113     0
#> 5: response              1      5 0.5007135     4 0.5141946 29.88113     0
#> 6: response              1      6 0.5133510     4 0.5141946 29.88113     0
#>    passengerClass  conf.low conf.high
#> 1:            1st 0.4336693 0.6018784
#> 2:            1st 0.4336693 0.6018784
#> 3:            1st 0.4336693 0.6018784
#> 4:            1st 0.4336693 0.6018784
#> 5:            1st 0.4336693 0.6018784
#> 6:            1st 0.4336693 0.6018784

This “long” format makes it easy to plot results:

ggplot(pred, aes(x = draw, fill = factor(woman))) +
    geom_density() +
    facet_grid(~ passengerClass, labeller = label_both) +
    labs(x = "Predicted probability of survival", y = "", fill = "Woman")
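
The long layout can be mimicked in base R, which may help when reading the drawid and draw columns. The sketch below starts from a simulated matrix of draws (an assumption for illustration, not data from the model above), stacks it into one line per draw, and then summarises each grid row with a median and an equal-tailed 95% interval. The package's own summary may differ, for example in the choice of point estimate or interval type.

```r
# Simulated matrix of draws: one row per posterior draw, one column per
# row of the prediction grid. Values are made up for illustration.
set.seed(99)
n_draws <- 5
n_rows  <- 3
draws <- matrix(rbeta(n_draws * n_rows, 2, 2), nrow = n_draws, ncol = n_rows)

# Stack the columns into "long" format: one line per (rowid, drawid) pair.
# as.vector() reads the matrix column by column, so rowid varies slowest.
long <- data.frame(
  rowid  = rep(seq_len(n_rows), each = n_draws),
  drawid = rep(seq_len(n_draws), times = n_rows),
  draw   = as.vector(draws)
)

# Summarise each grid row by its median and an equal-tailed 95% interval.
by_row <- split(long$draw, long$rowid)
summ <- data.frame(
  rowid     = seq_len(n_rows),
  median    = unname(sapply(by_row, median)),
  conf.low  = unname(sapply(by_row, quantile, probs = 0.025)),
  conf.high = unname(sapply(by_row, quantile, probs = 0.975))
)
summ
```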

Marginal effects

Use marginaleffects() to compute marginal effects (slopes of the regression equation) for each row of the dataset, and use summary() to compute “Average Marginal Effects”, that is, the average of all observation-level marginal effects:

mfx <- marginaleffects(mod)
summary(mfx)
#> Average marginal effects 
#>       type              Term    Effect     2.5 %    97.5 %
#> 1 response               age -0.005224 -0.007948 -0.002612
#> 2 response passengerClass2nd  0.021671 -0.029946  0.074093
#> 3 response passengerClass3rd -0.129620 -0.169244 -0.092489
#> 4 response             woman  0.365452  0.290606  0.443547
#> 
#> Model type:  brmsfit 
#> Prediction type:  response
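
As a rough illustration of what an average marginal effect is, the following base R sketch computes observation-level slopes by finite differences and then averages them. It uses simulated data and a frequentist glm fit purely to stay self-contained. This is an assumption made for portability, not a description of how marginaleffects works internally. With a Bayesian model, the same calculation would be repeated for every posterior draw.

```r
# Simulated data and a frequentist logit fit, used only to keep the
# example self-contained. Variable names are illustrative.
set.seed(1)
n <- 500
x <- rnorm(n)
y <- rbinom(n, size = 1, prob = plogis(-0.5 + 0.8 * x))
fit <- glm(y ~ x, family = binomial(link = "logit"))

# Observation-level slopes by central finite difference on the
# probability scale: dP(y = 1) / dx evaluated at each observed x.
h <- 1e-5
p_hi <- predict(fit, newdata = data.frame(x = x + h), type = "response")
p_lo <- predict(fit, newdata = data.frame(x = x - h), type = "response")
slope_numeric <- unname((p_hi - p_lo) / (2 * h))

# For a logit model without interactions the slope has a closed form:
# beta * p * (1 - p). It should match the numerical version.
p <- unname(fitted(fit))
slope_analytic <- unname(coef(fit)["x"]) * p * (1 - p)

# The average marginal effect is the mean of the observation-level slopes.
ame <- mean(slope_numeric)
ame
```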

Compute marginal effects with some regressors fixed at user-specified values, and other regressors held at their means:

marginaleffects(mod, newdata = typical(woman = 1, passengerClass = "1st"))
#>   rowid     type              term          dydx     conf.low     conf.high
#> 1     1 response             woman  0.1567983435  0.111834123  0.2095589106
#> 2     1 response               age -0.0002322851 -0.001331806  0.0009796472
#> 3     1 response passengerClass2nd -0.1565305092 -0.218441404 -0.1008674721
#> 4     1 response passengerClass3rd -0.3632796811 -0.441526785 -0.2954314481
#>        age woman passengerClass
#> 1 29.88113     1            1st
#> 2 29.88113     1            1st
#> 3 29.88113     1            1st
#> 4 29.88113     1            1st

Compute and plot conditional marginal effects:

plot_cme(mod, effect = "woman", condition = "age")

The get_posterior_draws function produces a dataset with drawid and draw columns:

draws <- get_posterior_draws(mfx)

dim(draws)
#> [1] 16736000       13

head(draws)
#>        type rowid_internal drawid      draw rowid  term      dydx  conf.low
#> 1: response              1      1 0.1583305     1 woman 0.1536635 0.1103399
#> 2: response              1      2 0.1925169     1 woman 0.1536635 0.1103399
#> 3: response              1      3 0.1660440     1 woman 0.1536635 0.1103399
#> 4: response              1      4 0.1508484     1 woman 0.1536635 0.1103399
#> 5: response              1      5 0.1536475     1 woman 0.1536635 0.1103399
#> 6: response              1      6 0.1725165     1 woman 0.1536635 0.1103399
#>    conf.high survived woman age passengerClass
#> 1: 0.2068255        1     1  29            1st
#> 2: 0.2068255        1     1  29            1st
#> 3: 0.2068255        1     1  29            1st
#> 4: 0.2068255        1     1  29            1st
#> 5: 0.2068255        1     1  29            1st
#> 6: 0.2068255        1     1  29            1st

We can use this dataset to plot our results. For example, to plot the posterior density of the marginal effect of age when the woman variable is equal to 0 or 1:

mfx <- marginaleffects(mod,
                       variables = "age",
                       newdata = typical(woman = 0:1)) |>
       get_posterior_draws()

ggplot(mfx, aes(x = draw, fill = factor(woman))) +
    stat_halfeye(slab_alpha = .5) +
    labs(x = "Marginal Effect of Age on Survival",
         y = "Posterior density",
         fill = "Woman")
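
The two densities differ because, in a logit model with a woman-by-age interaction, the slope with respect to age depends on woman in two ways: through the interaction coefficient and through the predicted probability itself. A sketch of the derivation follows. The β symbols are notation assumed here rather than names from the brms output, and the passengerClass terms are abbreviated by dots:

```latex
\eta = \beta_0 + \beta_1\,\text{woman} + \beta_2\,\text{age}
       + \beta_3\,(\text{woman} \times \text{age}) + \dots,
\qquad
p = \frac{1}{1 + e^{-\eta}}

\frac{\partial p}{\partial\,\text{age}}
  = \left(\beta_2 + \beta_3\,\text{woman}\right)\, p\,(1 - p)
```

Evaluating this expression for each posterior draw of the coefficients, once with woman = 0 and once with woman = 1, yields two distributions of slopes like the ones plotted above.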

Random effects model

This section replicates some of the analyses of a random effects model published in Andrew Heiss’s blog post: “A guide to correctly calculating posterior predictions and average marginal effects with multilevel Bayesian models.” The objective is mainly to illustrate the use of marginaleffects. Please refer to the original post for a detailed discussion of the quantities computed below.

Load libraries and clean data:

remotes::install_github("vdeminstitute/vdemdata")
library(vdemdata)
library(tidyverse)
library(marginaleffects)
library(brms)
library(ggdist)
library(patchwork)

vdem_2015 <- vdem %>%
  select(country_name, country_text_id, year, region = e_regionpol_6C,
         media_index = v2xme_altinf, party_autonomy_ord = v2psoppaut_ord,
         polyarchy = v2x_polyarchy, civil_liberties = v2x_civlib) %>%
  filter(year == 2015) %>%
  mutate(party_autonomy = party_autonomy_ord >= 3,
         party_autonomy = ifelse(is.na(party_autonomy), FALSE, party_autonomy)) %>%
  mutate(region = factor(region,
                         labels = c("Eastern Europe and Central Asia",
                                    "Latin America and the Caribbean",
                                    "Middle East and North Africa",
                                    "Sub-Saharan Africa",
                                    "Western Europe and North America",
                                    "Asia and Pacific")))

Fit a basic model:

mod <- brm(
  bf(media_index ~ party_autonomy + civil_liberties + (1 | region),
     phi ~ (1 | region)),
  data = vdem_2015,
  family = Beta(),
  control = list(adapt_delta = 0.9),
  backend = "cmdstanr", cores = 4,
  seed = 12345)

Posterior predictions

To compute posterior predictions for specific values of the regressors, we use the newdata argument and the typical function. We also use the type argument to compute two types of predictions: accounting for residual (observation-level) variance (prediction) or ignoring it (response).

pred <- predictions(mod,
                    type = c("response", "prediction"),
                    newdata = typical(party_autonomy = c(TRUE, FALSE),
                                      civil_liberties = .5,
                                      region = "Middle East and North Africa"))
pred
#>   rowid       type predicted party_autonomy civil_liberties
#> 1     7   response 0.6215307           TRUE             0.5
#> 2     8   response 0.3683854          FALSE             0.5
#> 3     7 prediction 0.6363109           TRUE             0.5
#> 4     8 prediction 0.3462030          FALSE             0.5
#>                         region   conf.low conf.high
#> 1 Middle East and North Africa 0.52950493 0.7119285
#> 2 Middle East and North Africa 0.27851895 0.4580605
#> 3 Middle East and North Africa 0.24836232 0.9769307
#> 4 Middle East and North Africa 0.03544802 0.7485465
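
The much wider intervals in the prediction rows can be reproduced in spirit with a few lines of base R. The sketch below simulates hypothetical draws of the mean (mu) and precision (phi) of a Beta outcome. These values are assumptions chosen for illustration, not draws from the model above. It then compares the interval of the expected value with the interval of simulated outcomes, using the mean-precision parameterisation shape1 = mu * phi and shape2 = (1 - mu) * phi.

```r
# Hypothetical posterior draws of the mean and precision of a Beta outcome.
set.seed(42)
n_draws <- 4000
mu  <- plogis(rnorm(n_draws, mean = 0.5, sd = 0.2))
phi <- exp(rnorm(n_draws, mean = 2, sd = 0.3))

# Expected value: ignores observation-level variance.
epred <- mu

# Posterior predictive: one simulated outcome per draw, which adds
# observation-level variance on top of the uncertainty in mu and phi.
ppred <- rbeta(n_draws, shape1 = mu * phi, shape2 = (1 - mu) * phi)

int_epred <- quantile(epred, probs = c(0.025, 0.975))
int_ppred <- quantile(ppred, probs = c(0.025, 0.975))
rbind(expected = int_epred, predictive = int_ppred)
```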

Extract posterior draws and plot them:

pred <- get_posterior_draws(pred)

ggplot(pred, aes(x = draw, fill = party_autonomy)) +
    stat_halfeye(alpha = .5) +
    facet_wrap(~ type) +
    labs(x = "Media index (predicted)", 
         y = "Posterior density",
         fill = "Party autonomy")

Marginal effects and contrasts

As noted in the Marginal Effects vignette, there should be one distinct marginal effect for each combination of regressor values. Here, we consider only one combination of regressor values, where region is “Middle East and North Africa”, and civil_liberties is 0.5. Then, we calculate the mean of the posterior distribution of marginal effects:

mfx <- marginaleffects(mod,
                       newdata = typical(civil_liberties = .5,
                                         region = "Middle East and North Africa"))
mfx
#>   rowid     type               term      dydx  conf.low conf.high
#> 1     1 response party_autonomyTRUE 0.2499799 0.1675054 0.3281245
#> 2     1 response    civil_liberties 0.8172178 0.6342737 1.0107915
#>   party_autonomy civil_liberties                       region
#> 1           TRUE             0.5 Middle East and North Africa
#> 2           TRUE             0.5 Middle East and North Africa

Use the get_posterior_draws() function to extract draws from the posterior distribution of marginal effects, and plot them:

mfx <- get_posterior_draws(mfx)

ggplot(mfx, aes(x = draw, y = term)) +
  stat_halfeye() +
  labs(x = "Marginal effect", y = "")

Plot marginal effects, conditional on a regressor:

plot_cme(mod,
         effect = "civil_liberties",
         condition = "party_autonomy")

Continuous predictors

pred <- predictions(mod,
                    newdata = typical(party_autonomy = FALSE,
                                      region = "Middle East and North Africa",
                                      civil_liberties = seq(0, 1, by = 0.05))) |>
        get_posterior_draws()


ggplot(pred, aes(x = civil_liberties, y = draw)) +
    stat_lineribbon() +
    scale_fill_brewer(palette = "Reds") +
    labs(x = "Civil liberties",
         y = "Media index (predicted)",
         fill = "")

The slope of this line for different values of civil liberties can be obtained with:

mfx <- marginaleffects(mod,
                       newdata = typical(civil_liberties = c(.2, .5, .8),
                                         party_autonomy = FALSE,
                                         region = "Middle East and North Africa"),
                       variables = "civil_liberties")
mfx
#>   rowid     type            term      dydx  conf.low conf.high civil_liberties
#> 1     1 response civil_liberties 0.4899448 0.3561016 0.6277562             0.2
#> 2     2 response civil_liberties 0.8090324 0.6133495 0.9922249             0.5
#> 3     3 response civil_liberties 0.8073873 0.6826792 0.9367713             0.8
#>   party_autonomy                       region
#> 1          FALSE Middle East and North Africa
#> 2          FALSE Middle East and North Africa
#> 3          FALSE Middle East and North Africa

And plotted:

mfx <- get_posterior_draws(mfx)

ggplot(mfx, aes(x = draw, fill = factor(civil_liberties))) +
    stat_halfeye(slab_alpha = .5) +
    labs(x = "Marginal effect of Civil Liberties on Media Index",
         y = "Posterior density",
         fill = "Civil liberties")

The marginaleffects function uses the ellipsis (...) to forward arguments to the brms prediction functions. This can alter the types of predictions returned. For example, the re_formula = NA argument of the posterior_predict.brmsfit method computes marginal effects without including any group-level effects:

mfx <- marginaleffects(mod,
                       newdata = typical(civil_liberties = c(.2, .5, .8),
                                         party_autonomy = FALSE,
                                         region = "Middle East and North Africa"),
                       variables = "civil_liberties",
                       re_formula = NA) |>
       get_posterior_draws()

ggplot(mfx, aes(x = draw, fill = factor(civil_liberties))) +
    stat_halfeye(slab_alpha = .5) +
    labs(x = "Marginal effect of Civil Liberties on Media Index",
         y = "Posterior density",
         fill = "Civil liberties")

Global grand mean

pred <- predictions(mod,
                    re_formula = NA,
                    newdata = typical(party_autonomy = c(TRUE, FALSE))) |>
        get_posterior_draws()

mfx <- marginaleffects(mod,
                       re_formula = NA,
                       variables = "party_autonomy") |>
       get_posterior_draws()

plot1 <- ggplot(pred, aes(x = draw, fill = party_autonomy)) +
         stat_halfeye(slab_alpha = .5) +
         labs(x = "Media index (Predicted)",
              y = "Posterior density",
              fill = "Party autonomy")

plot2 <- ggplot(mfx, aes(x = draw)) +
         stat_halfeye(slab_alpha = .5)  +
         labs(x = "Contrast: Party autonomy TRUE - FALSE",
              y = "",
              fill = "Party autonomy")

# combine plots using the `patchwork` package
plot1 + plot2

Region-specific predictions and contrasts

Predicted media index by region and level of civil liberties:

pred <- predictions(mod,
                    newdata = typical(region = vdem_2015$region,
                                      party_autonomy = FALSE, 
                                      civil_liberties = seq(0, 1, length.out = 100))) |> 
        get_posterior_draws()

ggplot(pred, aes(x = civil_liberties, y = draw)) +
    stat_lineribbon() +
    scale_fill_brewer(palette = "Reds") +
    facet_wrap(~ region) +
    labs(x = "Civil liberties",
         y = "Media index (predicted)",
         fill = "")

Posterior distributions of the predicted media index by region, at two levels of civil liberties:

pred <- predictions(mod,
                    newdata = typical(region = vdem_2015$region,
                                      civil_liberties = c(.2, .8),
                                      party_autonomy = FALSE)) |>
        get_posterior_draws()

ggplot(pred, aes(x = draw, fill = factor(civil_liberties))) +
    stat_halfeye(slab_alpha = .5) +
    facet_wrap(~ region) +
    labs(x = "Media index (predicted)",
         y = "Posterior density",
         fill = "Civil liberties")

Predicted media index by region and party autonomy:

pred <- predictions(mod,
                    newdata = typical(region = vdem_2015$region,
                                      party_autonomy = c(TRUE, FALSE),
                                      civil_liberties = .5)) |>
        get_posterior_draws()

ggplot(pred, aes(x = draw, y = region , fill = party_autonomy)) +
    stat_halfeye(slab_alpha = .5) +
    labs(x = "Media index (predicted)",
         y = "",
         fill = "Party autonomy")

TRUE/FALSE contrasts (marginal effects) of party autonomy by region:

mfx <- marginaleffects(mod,
                       variables = "party_autonomy",
                       newdata = typical(region = vdem_2015$region,
                                         civil_liberties = .5)) |>
        get_posterior_draws()

ggplot(mfx, aes(x = draw, y = region , fill = party_autonomy)) +
    stat_halfeye(slab_alpha = .5) +
    labs(x = "Contrast: Party autonomy TRUE - FALSE",
         y = "",
         fill = "Party autonomy")
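
For a logical regressor such as party_autonomy, the quantity plotted above is a difference between two expected values, computed draw by draw. A minimal base R sketch of that logic follows. The coefficient draws b0 and b1 are simulated on the logit scale as an assumption for illustration and are unrelated to the fitted model.

```r
# Hypothetical posterior draws of an intercept and a TRUE/FALSE coefficient
# on the logit scale. Values are made up for illustration.
set.seed(2024)
n_draws <- 2000
b0 <- rnorm(n_draws, mean = -0.6, sd = 0.15)
b1 <- rnorm(n_draws, mean = 1.0,  sd = 0.20)

# Expected outcome under each condition, one value per draw.
p_false <- plogis(b0)
p_true  <- plogis(b0 + b1)

# The contrast is computed within each draw, so it inherits the joint
# uncertainty in b0 and b1.
contrast <- p_true - p_false

c(median = median(contrast), quantile(contrast, probs = c(0.025, 0.975)))
```

Taking the difference inside each draw, rather than subtracting two summary statistics, is what produces a full posterior distribution for the contrast.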

Hypothetical groups

We can also obtain predictions or marginal effects for a hypothetical group instead of one of the observed regions. To achieve this, we create a dataset whose region column holds a level that was not observed in the original data (here, “Atlantis”). Then, we call the marginaleffects or predictions functions with the allow_new_levels = TRUE and re_formula = NULL arguments. These arguments are pushed through via the ellipsis (...) to the prediction functions of brms:

dat <- data.frame(civil_liberties = .5,
                  party_autonomy = FALSE,
                  region = "Atlantis")

marginaleffects(mod,
                variables = "party_autonomy",
                type = "response",
                newdata = dat,
                allow_new_levels = TRUE,
                re_formula = NULL) |>
     get_posterior_draws() |>
     ggplot(aes(x = draw)) +
     stat_halfeye()
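
One simple way to think about a hypothetical group is to draw its intercept from the normal distribution implied by the group-level standard deviation. The base R sketch below follows that idea with simulated draws. It is an illustration under that assumption only: brms offers several strategies for new levels, and this code does not claim to reproduce its default behaviour.

```r
# Hypothetical posterior draws: a population-level intercept on the logit
# scale and the standard deviation of the region intercepts.
set.seed(7)
n_draws   <- 4000
b0        <- rnorm(n_draws, mean = 0.3, sd = 0.1)
sd_region <- abs(rnorm(n_draws, mean = 0.5, sd = 0.1))

# Intercept of an unobserved region: one normal deviate per draw, scaled
# by that draw's group-level standard deviation.
u_new <- rnorm(n_draws, mean = 0, sd = sd_region)

# Expected outcome without group-level effects versus for the new group.
mu_pop <- plogis(b0)
mu_new <- plogis(b0 + u_new)

int_pop <- quantile(mu_pop, probs = c(0.025, 0.975))
int_new <- quantile(mu_new, probs = c(0.025, 0.975))
rbind(population = int_pop, new_group = int_new)
```

The interval for the new group is wider because it carries the between-region variation in addition to the uncertainty in the population-level intercept.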