Gaussian Processes

Smooth lines in fancy colours.

Goals of this lesson
  1. Let’s appreciate together the power of online community resources
  2. Gaussian processes are families of smooth functions we learn from data
  3. When used for prediction, a Gaussian process is both a “prior” and a “likelihood”

Background reading

Gaussian processes are surprisingly common, and there are lots of resources on the topic:

  1. The Stan manual has a chapter on it
  2. The Stan team gives lots of example models on GitHub, which I adapted for this example.
  3. Michael Betancourt has an extremely detailed, very rigorous tutorial on Gaussian processes
  4. Here’s a complete, worked analysis of human birthdays by world-class statisticians (in particular Andrew Gelman, Aki Vehtari and Daniel Simpson)
  5. Gaussian Processes are related to Generalized additive models (GAMs) and can be represented by a collection of basis functions. This representation is approximate but much, much (!) faster. See this excellent tutorial by Aki Vehtari, as well as the corresponding paper by Riutort-Mayol et al. (2023), also cited in the blog post.
  6. This blog applies Gaussian Processes to spatial count data
  7. Here is a very long and wonderfully detailed post describing a Gaussian Process approach to occupancy modelling
  8. Another blog post on Gaussian Processes, Hidden Markov Models and more, very clear explanation.

Reorganizing the mite data

Let’s begin by (once again!) loading and reorganizing the mite data. This time we’ll also use mite.xy, which gives the coordinates of each of the 70 samples. For the sake of this example, let’s focus on the species PWIL, because it has a rather strong relationship with water and its presences and absences are roughly balanced. This is just to make the illustration clear.

Loading models and data

library(rstan)
Loading required package: StanHeaders

rstan version 2.32.7 (Stan version 2.32.2)
For execution on a local, multicore CPU with excess RAM we recommend calling
options(mc.cores = parallel::detectCores()).
To avoid recompilation of unchanged Stan programs, we recommend calling
rstan_options(auto_write = TRUE)
For within-chain threading using `reduce_sum()` or `map_rect()` Stan functions,
change `threads_per_chain` option:
rstan_options(threads_per_chain = 1)
rstan_options("auto_write" = TRUE)
options(mc.cores = parallel::detectCores())

# mite data
data(mite, package = "vegan")
data(mite.env, package = "vegan")

## ALSO: the spatial data
data(mite.xy, package = "vegan")
pwil_data <- data.frame(plot_id = 1:nrow(mite),
                        WatrCont = mite.env$WatrCont, 
                        abd = mite$PWIL,
                        presabs = ifelse(mite$PWIL>0, 1, 0),
                        water = (mite.env$WatrCont - mean(mite.env$WatrCont))/100)

Let’s focus on a species with a rather strong relationship with water and for which presence and absence are roughly balanced. This is just to make the example clear.

# Plot data
plot(pwil_data$water, 
     pwil_data$presabs,
     pch = 19, 
     cex = 0.6, 
     col = adjustcolor("steelblue", 0.5),
     xlab = "Water (centered)",
     ylab = "Presence / Absence")

# Logistic regression
fit_pwil <- glm(presabs ~ water, 
                data = pwil_data, 
                family = binomial(link = "logit"))

# Draw prediction line
x_seq_pwil <- seq(min(pwil_data$water),
                  max(pwil_data$water), 
                  length.out = 100)

lines(x_seq_pwil, 
      predict(fit_pwil, 
              newdata = data.frame(water = x_seq_pwil),
              type = "response"), 
      col = "steelblue", 
      lwd = 2)

Probability of occurrence of one mite species, as a function of the water content of the soil.
# add the spatial coordinates:
pwil_spatial <- cbind(pwil_data, mite.xy)

pa_cols <- palette.colors(2, "Dark2")

plot(pwil_spatial$x, pwil_spatial$y,
     col  = pa_cols[pwil_spatial$presabs + 1],
     bg   = pa_cols[pwil_spatial$presabs + 1],
     pch  = 21, cex = 1.5,
     xlab = "x", ylab = "y",
     asp  = 1)

legend("topright", legend = c("Absent", "Present"),
       pch = 21, pt.bg = pa_cols, col = pa_cols)

Presence-absence data for mite species “PWIL”, at the spatial location of each point.

In this exercise, we will use a Gaussian process from two different angles:

  1. To build a nonlinear function of one variable to study how it relates to the occurrence of one species
  2. To build a model characterizing spatial autocorrelation in the distribution of one species

Smooth function of one variable

Write the model

\[ \begin{align} y_i &\sim \mathsf{Bernoulli}(p_i)\\ \mathsf{logit}(p_i) &= a + f_i\\ f &\sim \mathsf{multivariate\ normal}(0, K(x | \theta)) \\ K(x | \alpha, \rho, \sigma)_{i, j} &= \alpha^2 \exp \left( - \dfrac{1}{2 \rho^2} \sum_{d=1}^D (x_{i,d} - x_{j,d})^2 \right) + \delta_{i, j} \sigma^2, \end{align} \]

The equation above gives the general notation for a model with \(D\) predictor dimensions, using the exponentiated quadratic kernel to describe how the covariance between two points decays with the distance between them. In this first case study, we focus on the univariate version, where the only dimension is water content.

\[ \begin{align} y_i &\sim \mathsf{Bernoulli}(p_i)\\ \mathsf{logit}(p_i) &= a + f_i\\ f &\sim \mathsf{Multivariate\ Normal}(0, K(x | \theta)) \\ K(x | \alpha, \rho, \sigma)_{i, j} &= \alpha^2 e^{ \frac{-(\text{water}_i - \text{water}_j)^2}{2 \rho^2}} + \delta_{i, j} \sigma^2 \\ \rho &\sim \mathsf{Inverse\ Gamma}(5, 14) \\ \alpha &\sim \mathsf{Normal}(0, .8) \\ a &\sim \mathsf{Normal}(0, .2) \\ \end{align} \]

Here’s an interpretation of the parameters of this model:

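  • \(\alpha^2\) is the maximum covariance between two points, reached when they have identical water values
  • \(\rho\) is the length scale: it tells us how quickly that covariance declines as two samples become more different in their water content
  • \(\delta_{i, j} \sigma^2\) adds a variance \(\sigma^2\) along the diagonal only, because \(\delta_{i, j}\) is 1 when \(i = j\) and 0 otherwise. In the Stan programs below, this role is played by a tiny constant, delta = 1e-9, added to the diagonal for numerical stability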

See the explanation of this function in the Stan User’s guide
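
To make the three parameters concrete, here is a minimal R sketch of the kernel for a single covariate. The helper name exp_quad_kernel and the example values are illustrative assumptions, not part of the lesson's code. In Stan, gp_exp_quad_cov computes the first term and the diagonal term is added separately.

```r
# Exponentiated quadratic kernel for one covariate (illustrative helper).
# alpha: marginal standard deviation, rho: length scale, sigma: diagonal sd.
exp_quad_kernel <- function(x, alpha, rho, sigma = 0) {
  sq_dist <- outer(x, x, "-")^2            # (x_i - x_j)^2 for every pair
  K <- alpha^2 * exp(-sq_dist / (2 * rho^2))
  K + diag(sigma^2, length(x))             # delta_ij * sigma^2 on the diagonal
}

x_demo <- c(0, 1, 5)                        # made-up water values
K_demo <- exp_quad_kernel(x_demo, alpha = 0.8, rho = 2)
round(K_demo, 3)
```

The diagonal of K_demo equals alpha^2 = 0.64, and the covariance between the first point and the others shrinks as the distance grows. Increasing rho makes that decline slower, which gives smoother functions.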

Simulate to understand it

Here is the Stan code that replicates the mathematical model above.

gp_example_sim <- stan_model(file = "topics/04_gp/gp_example_sim.stan")

gp_example_sim
S4 class stanmodel 'anon_model' coded as follows:
// Fit the hyperparameters of a latent-variable Gaussian process with an
// exponentiated quadratic kernel and a Bernoulli likelihood
// This code is from https://github.com/stan-dev/example-models/blob/master/misc/gaussian-process/gp-fit-logit.stan
data {
  int<lower=1> N;
  array[N] real x;
}
transformed data {
  real delta = 1e-9;
}
parameters {
  real<lower=0> rho;
  real<lower=0> alpha;
  real a;
  vector[N] eta;
  vector[N] y; // unused and given no prior: likely contributes to the sampler warnings below
}
transformed parameters {
  vector[N] f;
  {
    matrix[N, N] L_K;
    matrix[N, N] K = gp_exp_quad_cov(x, alpha, rho);

    // diagonal elements
    for (n in 1 : N) {
      K[n, n] = K[n, n] + delta;
    }

    L_K = cholesky_decompose(K);
    f = L_K * eta;
  }
}
model {
  rho ~ inv_gamma(5, 14);
  alpha ~ normal(0, .8);
  a ~ normal(0, .2);
  eta ~ std_normal();
} 
gp_example_sim_samples <- sampling(gp_example_sim,
                                   data = list(N = 20,x = seq(from = -3,
                                                              to = 5, 
                                                              length.out = 20)),
                                   chains = 4,
                                   iter = 2000)
Warning: There were 274 divergent transitions after warmup. See
https://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
to find out why this is a problem and how to eliminate them.
Warning: There were 3248 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
https://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
Warning: Examine the pairs() plot to diagnose sampling problems
Warning: The largest R-hat is 3.07, indicating chains have not mixed.
Running the chains for more iterations may help. See
https://mc-stan.org/misc/warnings.html#r-hat
Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
Running the chains for more iterations may help. See
https://mc-stan.org/misc/warnings.html#bulk-ess
Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
Running the chains for more iterations may help. See
https://mc-stan.org/misc/warnings.html#tail-ess
x_pred_prior <- seq(from = -3, to = 5, length.out = 20)
draws_f_prior <- rstan::extract(gp_example_sim_samples, "f")$f
draws_a_prior <- as.vector(rstan::extract(gp_example_sim_samples, "a")$a)
draw_idx_prior <- sample(length(draws_a_prior), 45)

plot(0,
     0,
     type = "n",
     xlim = range(x_pred_prior), 
     ylim = c(0, 1),
     xlab = "Water",
     ylab = "Pr(presence)")

for (k in draw_idx_prior) {
  lines(x_pred_prior, 
        plogis(draws_f_prior[k, ] + draws_a_prior[k]),
        col = adjustcolor("steelblue", 0.4))
}
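
The same prior can also be simulated directly in R, without MCMC, which may help to see what the transformed parameters block is doing. The sketch below is an assumption-laden illustration rather than the lesson's method: the function simulate_gp_prior is hypothetical, it draws each parameter from the priors written in the model block, and it builds f as L_K times a standard normal vector, as the Stan program does.

```r
set.seed(1)
x_sim <- seq(from = -3, to = 5, length.out = 20)

# One prior draw of Pr(presence) along x (hypothetical helper)
simulate_gp_prior <- function(x, jitter = 1e-9) {
  rho   <- 1 / rgamma(1, shape = 5, rate = 14)  # inv_gamma(5, 14)
  alpha <- abs(rnorm(1, 0, 0.8))                # half-normal, because alpha > 0
  a     <- rnorm(1, 0, 0.2)
  K <- alpha^2 * exp(-outer(x, x, "-")^2 / (2 * rho^2))
  diag(K) <- diag(K) + jitter                   # same role as delta in Stan
  L_K <- t(chol(K))                             # chol() returns the upper factor
  f <- as.vector(L_K %*% rnorm(length(x)))      # f = L_K * eta
  plogis(a + f)
}

prior_curves <- replicate(45, simulate_gp_prior(x_sim))  # 20 x 45 matrix
matplot(x_sim, prior_curves, type = "l", lty = 1,
        col = adjustcolor("steelblue", 0.4),
        ylim = c(0, 1), xlab = "Water", ylab = "Pr(presence)")
```

Each column of prior_curves is one smooth curve implied by the priors, so the plot should resemble the one produced from the Stan samples.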

Express that model in code

With a working simulation, we can now adapt the model to handle real data.

gp_example_pred <- stan_model(
  file = "topics/04_gp/gp_example_pred.stan")

gp_example_pred
S4 class stanmodel 'anon_model' coded as follows:
// Fit the hyperparameters of a latent-variable Gaussian process with an
// exponentiated quadratic kernel and a Bernoulli likelihood
// This code is from https://github.com/stan-dev/example-models/blob/master/misc/gaussian-process/gp-fit-logit.stan
data {
  int<lower=1> Nobs;
  int<lower=1> N;
  array[N] real x;
  array[Nobs] int<lower=0, upper=1> z;
}
transformed data {
  real delta = 1e-9;
}
parameters {
  real<lower=0> rho;
  real<lower=0> alpha;
  real a;
  vector[N] eta;
}
transformed parameters {
  vector[N] f;
  {
    matrix[N, N] L_K;
    matrix[N, N] K = gp_exp_quad_cov(x, alpha, rho);

    // diagonal elements
    for (n in 1 : N) {
      K[n, n] = K[n, n] + delta;
    }

    L_K = cholesky_decompose(K);
    f = L_K * eta;
  }
}
model {
  rho ~ inv_gamma(5, 14);
  alpha ~ normal(0, .8);
  a ~ normal(0, .2);
  eta ~ std_normal();

  z ~ bernoulli_logit(a + f[1:Nobs]);
} 

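We need to generate data for making predictions! I’ll create a new vector of values called new_x that covers the range of the water variable in our dataset. These values are appended to the observed water values, so the model estimates f at both the observed and the new points, while only the first Nobs elements of f enter the likelihood.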

# 15 evenly spaced values spanning the range of the centered water variable
new_x <- seq(from = -3, to = 5, length.out = 15)

gp_data_list <- list(N = length(pwil_spatial$presabs) + length(new_x),
                     Nobs = length(pwil_spatial$presabs),
                     x = c(pwil_spatial$water, new_x),
                     z = pwil_spatial$presabs)

# sample the posterior
gp_example_pwil_samp <- sampling(gp_example_pred,
                                 data = gp_data_list,
                                 chains = 4, 
                                 iter = 2000, 
                                 refresh = 1000)
Nobs_pwil <- length(pwil_spatial$presabs)
draws_f_pwil <- rstan::extract(gp_example_pwil_samp, "f")$f
draws_a_pwil <- as.vector(rstan::extract(gp_example_pwil_samp, "a")$a)
draws_pred <- plogis(draws_f_pwil[, (Nobs_pwil + 1):ncol(draws_f_pwil)] +
                     draws_a_pwil)

q_pred <- apply(draws_pred,
                2,
                quantile,
                probs = c(0.025, 0.25, 0.5, 0.75, 0.975))

plot(pwil_spatial$water, 
     jitter(pwil_spatial$presabs, factor = 0.05),
     pch = 19,
     cex = 0.5, 
     col = adjustcolor("black", 0.4),
     xlab = "Water (centered)",
     ylab = "Pr(presence)",
     ylim = c(-0.05, 1.05))

polygon(c(new_x, rev(new_x)), 
        c(q_pred[1, ], rev(q_pred[5, ])),
        col = adjustcolor("red3", 0.15),
        border = NA)

polygon(c(new_x, rev(new_x)), 
        c(q_pred[2, ], rev(q_pred[4, ])),
        col = adjustcolor("red3", 0.30), 
        border = NA)

lines(new_x, 
      q_pred[3, ], 
      col = "red3", 
      lwd = 2)

A Gaussian Process fits a distribution of smooth functions to a dataset. Here we’re using it to estimate the effect of water content on the occurrence of a mite. The line is the posterior median, and the dark and light bands are the 50% and 95% intervals.
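
One detail in the plotting code is easy to miss: the draws of f form a matrix (one row per posterior draw) and the draws of a form a vector, and the two are simply added. R recycles the vector down the columns, so as long as it has one element per row, each draw's intercept is added to that draw's whole curve. A tiny sketch with made-up numbers (f_toy and a_toy are illustrative names):

```r
f_toy <- matrix(0, nrow = 3, ncol = 4)   # 3 draws, 4 prediction points
a_toy <- c(10, 20, 30)                   # one intercept per draw
eta_toy <- f_toy + a_toy
eta_toy
# each row is constant: row 1 is all 10, row 2 all 20, row 3 all 30

# Summarise across draws, one column per prediction point
apply(plogis(eta_toy - 20), 2, quantile, probs = c(0.025, 0.5, 0.975))
```

If the matrix were stored the other way around (draws in columns), the same addition would silently mix intercepts across draws, so it is worth checking the dimensions before summarising.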

We can also pull out some specific functions. What I want you to see here is that there are MANY curvy lines that are consistent with this model.

draw_idx_63 <- sample(length(draws_a_pwil), 63)

plot(0,
     0, 
     type = "n",
     xlim = range(new_x), 
     ylim = c(0, 1),
     xlab = "Water (centered)",
     ylab = "Pr(presence)")

for (k in draw_idx_63) {
  prob_k <- plogis(draws_f_pwil[k, (Nobs_pwil + 1):ncol(draws_f_pwil)] +
                   draws_a_pwil[k])
  
  lines(new_x, prob_k, col = adjustcolor("black", 0.3))
}

Spatial predictions

To make a prediction of a function on one X variable, we needed a sequence of points to predict along.

To make spatial predictions, we need a grid of points to predict along.

grid_points <- expand.grid(x = seq(min(mite.xy$x), max(mite.xy$x), by = 0.25),
                           y = seq(min(mite.xy$y), max(mite.xy$y), by = 0.25))

plot(grid_points$x,
     grid_points$y,
     pch = 19, 
     cex = 0.4,
     xlab = "x",
     ylab = "y",
     asp = 1,
     xlim = range(grid_points$x) + c(-0.25, 0.25))

Mathematically, the model is slightly different than for the previous case because the exponentiated quadratic function now needs to account for two dimensions. So, our model looks like this:

\[ \begin{align} y_i &\sim \mathsf{Bernoulli}(p_i)\\ \mathsf{logit}(p_i) &= a + f_i\\ f &\sim \mathsf{multivariate\ normal}(0, K(x | \theta)) \\ K(x | \alpha, \rho, \sigma)_{i, j} &= \alpha^2 \exp \left( - \dfrac{1}{2 \rho^2} \sum_{d=1}^2 (x_{i,d} - x_{j,d})^2 \right) + \delta_{i, j} \sigma^2, \end{align} \]

However, from a coding standpoint, other than a change in the data {} block, the Stan code is unchanged! This is because the gp_exp_quad_cov Stan function was designed to work with inputs of any number of dimensions.
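
The reason one function can serve both cases is that the kernel only depends on the squared Euclidean distance between points. Here is a small R sketch of that idea, assuming a hypothetical helper exp_quad_kernel_nd and made-up coordinates. It is not how Stan implements gp_exp_quad_cov, only an illustration of the formula.

```r
# Kernel for points with any number of coordinates (illustrative helper).
# coords: matrix with one row per point and one column per dimension.
exp_quad_kernel_nd <- function(coords, alpha, rho) {
  d <- as.matrix(dist(coords))             # Euclidean distance between rows
  alpha^2 * exp(-d^2 / (2 * rho^2))
}

pts_1d <- matrix(c(0, 1, 2), ncol = 1)            # one covariate
pts_2d <- cbind(x = c(0, 1, 0), y = c(0, 0, 2))   # spatial coordinates

K_1d <- exp_quad_kernel_nd(pts_1d, alpha = 0.8, rho = 2)
K_2d <- exp_quad_kernel_nd(pts_2d, alpha = 0.8, rho = 2)
round(K_2d, 3)
```

In both cases the result is an N by N covariance matrix. Only the way distance is measured changes, which is why the rest of the Stan program can stay the same.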

Prior predictive simulations

gp_example_2D_prior <- stan_model(
  file = "topics/04_gp/gp_example_2D_prior.stan")

gp_example_2D_prior
S4 class stanmodel 'anon_model' coded as follows:
// Fit the hyperparameters of a latent-variable Gaussian process with an
// exponentiated quadratic kernel and a Bernoulli likelihood
// This code is from https://github.com/stan-dev/example-models/blob/master/misc/gaussian-process/gp-fit-logit.stan
data {
  // int<lower=1> Nobs;
  int<lower=1> N;
  array[N] vector[2] x;
  real rho_a;
  real rho_b;
  // array[Nobs] int<lower=0, upper=1> z;
}
transformed data {
  real delta = 1e-9;
}
parameters {
  real<lower=0> rho;
  real<lower=0> alpha;
  real a;
  vector[N] eta;
}
transformed parameters {
  vector[N] f;
  {
    matrix[N, N] L_K;
    matrix[N, N] K = gp_exp_quad_cov(x, alpha, rho);

    // diagonal elements
    for (n in 1 : N) {
      K[n, n] = K[n, n] + delta;
    }

    L_K = cholesky_decompose(K);
    f = L_K * eta;
  }
}
model {
  rho ~ inv_gamma(rho_a, rho_b);
  alpha ~ normal(0, .8);
  a ~ normal(0, .2);
  eta ~ std_normal();
  // z ~ bernoulli_logit(a + f[1:Nobs]);
} 
gp_example_2D_prior_samp <- sampling(gp_example_2D_prior,
                                     data = list(N = nrow(grid_points),
                                                 x = grid_points,
                                                 rho_a = 5, 
                                                 rho_b = 14),
                                     chains = 4, 
                                     iter = 2000,
                                     refresh = 1000)

Visualize the prior

## extract the predictors
draws_f2d_prior <- rstan::extract(gp_example_2D_prior_samp, "f")$f
draws_a2d_prior <- as.vector(rstan::extract(gp_example_2D_prior_samp, "a")$a)
draw_idx_2d_prior <- sample(length(draws_a2d_prior), 6)

x_uniq <- sort(unique(grid_points$x))
y_uniq <- sort(unique(grid_points$y))

par(mfrow = c(2, 3), mar = c(1, 1, 1.5, 0.5))

for (k in draw_idx_2d_prior) {
  pa_k  <- plogis(draws_f2d_prior[k, ] + draws_a2d_prior[k])
  z_mat <- matrix(pa_k, nrow = length(x_uniq), ncol = length(y_uniq))
  image(x_uniq, y_uniq, z_mat,
        col  = hcl.colors(100, "Inferno"),
        xlab = "", ylab = "", main = "Pr(y=1)", asp = 1)
}
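
The reshaping step relies on the order in which expand.grid lists its rows: the first variable (x) varies fastest. Filling a matrix column by column with length(x_uniq) rows therefore puts x along the rows and y along the columns, which is the layout image() expects. A toy check with made-up values (gx, gy and toy_grid are illustrative names):

```r
gx <- c(0, 1, 2)                           # 3 x values
gy <- c(10, 20)                            # 2 y values
toy_grid <- expand.grid(x = gx, y = gy)    # x varies fastest
toy_grid$value <- toy_grid$x + toy_grid$y  # made-up surface

z_toy <- matrix(toy_grid$value, nrow = length(gx), ncol = length(gy))
# z_toy[i, j] is the value at (gx[i], gy[j])
z_toy

image(gx, gy, z_toy, col = hcl.colors(100, "Inferno"),
      xlab = "x", ylab = "y")
```

If the grid were built with y first, or sorted differently, the matrix would need to be filled accordingly, otherwise the map would be scrambled without any error message.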

CAUTION: Slow

The model below is the slowest model we’ve seen so far and takes about 5 minutes to run on my laptop.

gp_example_pred_2D <- stan_model(
  file = "topics/04_gp/gp_example_pred_2D.stan")

gp_example_pred_2D
S4 class stanmodel 'anon_model' coded as follows:
// Fit the hyperparameters of a latent-variable Gaussian process with an
// exponentiated quadratic kernel and a Bernoulli likelihood
// This code is from https://github.com/stan-dev/example-models/blob/master/misc/gaussian-process/gp-fit-logit.stan
data {
  int<lower=1> Nobs;
  int<lower=1> N;
  array[N] vector[2] x;
  array[Nobs] int<lower=0, upper=1> z;
}
transformed data {
  real delta = 1e-9;
}
parameters {
  real<lower=0> rho;
  real<lower=0> alpha;
  real a;
  vector[N] eta;
}
transformed parameters {
  vector[N] f;
  {
    matrix[N, N] L_K;
    matrix[N, N] K = gp_exp_quad_cov(x, alpha, rho);

    // diagonal elements
    for (n in 1 : N) {
      K[n, n] = K[n, n] + delta;
    }

    L_K = cholesky_decompose(K);
    f = L_K * eta;
  }
}
model {
  rho ~ inv_gamma(5, 14);
  alpha ~ normal(0, .8);
  a ~ normal(0, .2);
  eta ~ std_normal();

  z ~ bernoulli_logit(a + f[1:Nobs]);
} 

Plot the effect in space:

## sample the model
gp_example_2D_samp <- sampling(gp_example_pred_2D,
                               data = list(N = length(pwil_spatial$presabs) + nrow(grid_points),
                                           Nobs = length(pwil_spatial$presabs),
                                           x = rbind(pwil_spatial[c("x", "y")], grid_points),
                                           z = pwil_spatial$presabs),
                               chains = 4, 
                               iter = 1000)


# Optionally cache the fitted model so it does not need to be re-run:
# saveRDS(gp_example_2D_samp, "topics/04_gp/gp_example_2D_samp_pwil.rds")
# gp_example_2D_samp <- readRDS("topics/04_gp/gp_example_2D_samp_pwil.rds")

## extract the predictors
Nobs_2d <- length(pwil_spatial$presabs)
draws_f2d <- rstan::extract(gp_example_2D_samp, "f")$f
draws_a2d <- as.vector(rstan::extract(gp_example_2D_samp, "a")$a)
draws_pred_2d <- plogis(draws_f2d[, (Nobs_2d + 1):ncol(draws_f2d)] + draws_a2d)
pa_median_2d <- apply(draws_pred_2d, 2, median)

z_mat_2d <- matrix(pa_median_2d, nrow = length(x_uniq), ncol = length(y_uniq))

image(x_uniq, y_uniq, z_mat_2d,
      col  = hcl.colors(100, "Inferno"),
      xlab = "x", ylab = "y", main = "Pr(y=1) — posterior median", asp = 1)
# overlay observed points
pt_col <- ifelse(pwil_spatial$presabs == 1, "white", "lightblue")
points(pwil_spatial$x, pwil_spatial$y,
       pch = 21, cex = 1.5, bg = pt_col, col = "grey40")

Extensions:

Add water to the model. Does the spatial effect disappear, increase, or stay kind of the same?

Next step: try to model the water curve for more than one species. Would it be possible to make the species rho parameters hierarchical?