DynForest for survival outcome

Overview

DynForest is a parallelized R package with two main functions: DynForest(), which builds the random forest, and predict(), which predicts the outcome for new subjects. Longitudinal predictors are modeled within the random forest using the lcmm R package (see Proust-Lima et al., 2017). Although DynForest was designed to handle longitudinal predictors, it can also be used with time-fixed predictors only, like a conventional random forest.

Illustration on pbc2 data with survival outcome

We use DynForest on the pbc2 dataset (see Murtaugh et al., 1994) to illustrate our methodology. Data come from the clinical trial conducted by the Mayo Clinic between 1974 and 1984. For the illustration, we consider a subsample of the original dataset including 312 patients and 7 predictors. Among these predictors, the levels of serum bilirubin (serBilir), aspartate aminotransferase (SGOT), albumin and alkaline were measured at inclusion and during the follow-up, leading to a total of 1945 observations. Sex, age and the drug treatment were collected at enrollment. During the follow-up, 140 patients died, 29 patients were transplanted and 143 patients were still alive. The time of last follow-up (alive or any event) was considered as the event time. We treat transplantation as a competing event in this illustration. We aim at predicting death in patients suffering from primary biliary cholangitis (PBC) using clinical and socio-demographic predictors.

Build predictors objects

To begin, we load the DynForest package and the pbc2 data, and we split the subjects into two datasets: (i) a training dataset with \(2/3\) of the patients to build the random forest; (ii) a dataset with the remaining \(1/3\) of the patients to predict on.

# load package
library(DynForest)

# load data
data(pbc2)

# Split the data for training and prediction steps
set.seed(1234)
id <- unique(pbc2$id)
id_sample <- sample(id, length(id)*2/3)
id_row <- which(pbc2$id%in%id_sample)
pbc2_train <- pbc2[id_row,]
pbc2_pred <- pbc2[-id_row,]

Then, we build the dataframe for the longitudinal predictors, including serBilir, SGOT, albumin and alkaline along with the subject identifier (id) and the measurement times (time).

These predictors were previously normalized to satisfy the Gaussian assumption of the linear mixed models. We also build the dataframe with the time-fixed predictors, including age, drug and sex along with the subject identifier (id).

Be careful to define the predictors with their correct nature, in particular drug and sex as discrete predictors using the as.factor() function. The nature of the predictors can be checked using the str() function.

# Build predictors objects
timeData_train <- pbc2_train[,c("id","time",
                                "serBilir","SGOT",
                                "albumin","alkaline")]
fixedData_train <- unique(pbc2_train[,c("id","age","drug","sex")])
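
As a minimal sketch of the check described above, the predictor types can be inspected with str() and, if needed, drug and sex converted into factors (the as.factor() calls are harmless when the variables are already factors):

# Check the nature of the predictors
str(fixedData_train)

# Convert drug and sex into factors if they are not already
fixedData_train$drug <- as.factor(fixedData_train$drug)
fixedData_train$sex <- as.factor(fixedData_train$sex)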

Define association for each longitudinal predictor

The first step is to build the random forest using the DynForest() function. We need to define the association of each longitudinal predictor through a list containing the fixed and random formulas (as defined in the lcmm package). To allow more flexible associations, splines can be used in the formulas through the splines package.

# Create object with longitudinal association for each predictor
timeVarModel <- list(serBilir = list(fixed = serBilir ~ time,
                                     random = ~ time),
                     SGOT = list(fixed = SGOT ~ time + I(time^2),
                                 random = ~ time + I(time^2)),
                     albumin = list(fixed = albumin ~ time,
                                    random = ~ time),
                     alkaline = list(fixed = alkaline ~ time,
                                     random = ~ time))
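
If a more flexible shape were needed, natural cubic splines from the splines package could be used in the formulas. The sketch below is only illustrative, with an arbitrary number of degrees of freedom:

# Sketch: flexible trajectory for serBilir using natural cubic splines
# (df = 3 chosen arbitrarily for illustration)
library(splines)
timeVarModel_splines <- list(serBilir = list(fixed = serBilir ~ ns(time, df = 3),
                                             random = ~ ns(time, df = 3)))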

Build outcome object

The outcome of interest is defined using a list containing the type of outcome (type) and the outcome data (Y).

For this illustration, we choose surv as the type of outcome. Y is defined with the unique rows of the subject identifier (id), the time-to-event (years) and the event indicator (event) from the training data:

# Build outcome object
Y <- list(type = "surv",
          Y = unique(pbc2_train[,c("id","years","event")]))

Build the random survival forest

To execute the DynForest() function, several arguments are mandatory: the longitudinal predictors (timeData) with their associations (timeVarModel), the time-fixed predictors (fixedData), the names of the time and identifier variables (timeVar and idVar), and the outcome object (Y).

In a survival context with multiple events, it is also necessary to specify the event of interest with the argument cause.

# Build the random forest
res_dyn <- DynForest(timeData = timeData_train, fixedData = fixedData_train,
                     timeVar = "time", idVar = "id", timeVarModel = timeVarModel,
                     ntree = 200, nodesize = 5, minsplit = 5, cause = 2,
                     Y = Y, seed = 1234)

Many other arguments can be used in the DynForest() function. Among them, ntree, mtry, nodesize and minsplit are hyperparameters which were arbitrarily chosen for the illustration. However, these hyperparameters should be tuned to improve the predictive performance of the random forest. The seed argument can be used to reproduce the same random forest, and ncores defines the number of CPU cores used by the function. By default, ncores is set to the total number of cores minus 1.
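
As a hedged sketch (the values of mtry and ncores below are chosen arbitrarily for illustration), the same forest could be built while setting these two arguments explicitly:

# Sketch: same call with mtry and the number of CPU cores set explicitly
res_dyn_custom <- DynForest(timeData = timeData_train, fixedData = fixedData_train,
                            timeVar = "time", idVar = "id", timeVarModel = timeVarModel,
                            ntree = 200, mtry = 3, nodesize = 5, minsplit = 5, cause = 2,
                            Y = Y, ncores = 2, seed = 1234)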

The DynForest() function returns an object of class DynForest. Among the returned elements, $rf provides all the information about the trees, in particular the $V_split element describing the splits. The split details for tree 1 can be displayed with the following code:

head(res_dyn$rf[,1]$V_split)

tail(res_dyn$rf[,1]$V_split)

The table is sorted by the node/leaf identifier (num_noeud column), each row representing a node or a leaf. The other columns provide information about the splits.

For instance, regarding the interpretation of the node splits, the subjects were split at node 1 using the first random effect (\(\texttt{var_summary} = 1\)) of the third Curve predictor (\(\texttt{var_split} = 3\)) with \(\texttt{threshold} = -0.2199\). The predictor name can be retrieved from the predictor number in the timeData and fixedData datasets. Therefore, the subjects at node 1 with a value of this feature of albumin below -0.2199 drop into node 2, otherwise into node 3. As another example with the leaves, 4 subjects are included in node 192, among which 2 had the event of interest. The estimated cumulative incidence functions (CIF) are provided in the $Y_pred element of $rf. $Y_pred is a list containing the CIF of each cause at each time of interest. For instance, the CIF of the cause of interest for leaf 192 can be displayed using the following code:

# Display CIF for cause of interest
plot(res_dyn$rf[,1]$Y_pred[[192]]$`2`, type = "l", col = "red", 
     xlab = "Years", ylab = "CIF", ylim = c(0,1))
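
As noted above, the predictor number in var_split can be mapped back to a predictor name. A minimal sketch, assuming var_split indexes the predictor columns of timeData (i.e., excluding id and time) for Curve predictors:

# Hypothetical lookup: the third Curve predictor corresponds to the third
# longitudinal predictor column of timeData_train (here, albumin)
setdiff(colnames(timeData_train), c("id", "time"))[3]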

Out-Of-Bag error

The Out-Of-Bag (OOB) error is computed using the compute_OOBerror() function. In addition to the DynForest object, compute_OOBerror() returns the OOB error by individual ($xerror) and by tree ($oob.err). The overall OOB error for the random forest is obtained by averaging the OOB error (over individuals or trees) and is also given by the summary() function. In a survival context, the OOB error is evaluated using the Integrated Brier Score (IBS) from 0 to the maximum event time. This time range can be modified using the IBS.min and IBS.max arguments, which define the minimum and maximum, respectively. To improve the predictive ability of the random forest, we want to minimize the OOB error, which can be done by tuning the hyperparameters.

# Compute OOB error
res_dyn_OOB <- compute_OOBerror(DynForest_obj = res_dyn)
# Get summary
summary(res_dyn_OOB)
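
If needed, the IBS time range can be restricted; for instance, a sketch computing the OOB error over the first 10 years only (the bounds here are arbitrary):

# OOB error with the Integrated Brier Score computed from 0 to 10 years
res_dyn_OOB_10 <- compute_OOBerror(DynForest_obj = res_dyn,
                                   IBS.min = 0, IBS.max = 10)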

Predict the risk on new subjects

In this step, we want to predict the CIF for new subjects using the estimated random forest. Dynamic predictions can be computed by fixing a landmark time; the longitudinal data are censored at this time. For the illustration, we only select the subjects still at risk at 4 years. Then, we build the data for those subjects and we predict the cumulative incidence function (CIF) using the predict() function as follows:

# Build data for prediction
id_pred <- unique(pbc2_pred$id[which(pbc2_pred$years>4)])
pbc2_pred <- pbc2_pred[which(pbc2_pred$id%in%id_pred),]
timeData_pred <- pbc2_pred[,c("id", "time", "serBilir", "SGOT", "albumin", "alkaline")]
fixedData_pred <- unique(pbc2_pred[,c("id","age","drug","sex")])

# Prediction step
pred_dyn <- predict(object = res_dyn, 
                    timeData = timeData_pred, fixedData = fixedData_pred,
                    idVar = "id", timeVar = "time",
                    t0 = 4)

timeData, fixedData, idVar and timeVar are the same arguments as detailed for the DynForest() function. In addition to these arguments, the landmark time is given through the t0 argument; the longitudinal data are truncated at this time.

The predict() function provides multiple elements, in particular the CIF predicted from the landmark time for each new subject.

The plot_CIF() function displays the CIF of the event of interest for given subjects. For instance, we compare the CIF from the landmark time for subjects 102 and 260.

plot_CIF(DynForestPred_obj = pred_dyn,
         id = c(102, 260))

Explore the most predictive variables

VIMP

The main objective of the random forest is to predict an outcome, but we can also be interested in identifying the most predictive variables. The VIMP statistic can be computed using the compute_VIMP() function. In addition to the DynForest object, this function returns the VIMP statistic of each predictor in the $Importance element. These results can be displayed using the plot_VIMP() function.

# Compute VIMP statistic
res_dyn_VIMP <- compute_VIMP(DynForest_obj = res_dyn_OOB)

# Plot VIMP statistic
plot_VIMP(res_dyn_VIMP)

We find that the most predictive variables are serBilir and albumin, which have the largest VIMP.

gVIMP

To evaluate the importance of a group of several predictors, the gVIMP statistic can be computed with the compute_gVIMP() function. This function has a group argument to define the groups of predictors through a list. For instance, with two groups of predictors (named group1 and group2), the gVIMP statistic is computed using the following code:

# Define groups
group <- list(group1 = c("serBilir","SGOT"),
              group2 = c("albumin","alkaline"))

# Compute gVIMP statistic
res_dyn_gVIMP <- compute_gVIMP(DynForest_obj = res_dyn_OOB,
                               group = group)

# Plot gVIMP statistic
plot_gVIMP(res_dyn_gVIMP)

As with the VIMP statistic, the gVIMP results can be displayed using the plot_gVIMP() function. The figure indicates that group1 has the greatest predictive ability. We also observe that the gVIMP of group2 is lower than the sum of the VIMP of the two predictors in this group (albumin and alkaline). This figure shows why it can be relevant to compute the gVIMP statistic to assess the predictive ability of a group of predictors.

To compute the gVIMP statistic, the groups can be defined regardless of the number of predictors. However, the comparison between groups can be harder when their sizes differ.

Average minimal depth

To go further in the understanding of the tree building process, the var_depth() function extracts useful information: the average minimal depth by feature ($min_depth), a table with the minimal depth of each feature (rows) in each tree (columns) ($var_node_depth), and a table with the number of times each feature (rows) is used for splitting in each tree (columns) ($var_count). From the var_depth() object, the plot_mindepth() function plots the distribution of the average minimal depth across the trees. The plot_level argument defines whether the average minimal depth is plotted by predictor or by feature.

# Extract tree building information
depth_dyn <- var_depth(res_dyn)

# Plot average minimal depth by predictor
plot_mindepth(var_depth_obj = depth_dyn,
              plot_level = "predictor")

# Plot average minimal depth by feature
plot_mindepth(var_depth_obj = depth_dyn,
              plot_level = "feature")

The distribution of the minimal depth level is displayed by predictor and by feature. The minimal depth level should always be interpreted together with the number of trees in which the predictor/feature is found. For instance, we observe that serBilir and albumin have the lowest minimal depth, indicating that these predictors are used to split the subjects at an early stage in 200 out of 200 trees, i.e., 100%. In particular scenarios, we might observe a low minimal depth level while the predictor appears in only a few trees. This can occur due to the randomness of the method, especially when the mtry hyperparameter is close to 1 or ntree is not large enough.

The minimal depth level by feature provides more details about the tree building process. For instance, we can see that the first random effect (indicated by bi0 on the graph) of serBilir and albumin are the earliest features used, in 200 and 199 out of 200 trees, respectively.

Because the number of trees in which a predictor/feature is found depends on the mtry hyperparameter, the minimal depth can also be computed on a random forest built with mtry chosen at its maximum, as sketched below.
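
For instance, a sketch of this strategy refits the forest with mtry equal to the total number of predictors (7 here), keeping the other hyperparameters from the previous call, before extracting the minimal depth again:

# Refit the forest with mtry at its maximum (all 7 predictors drawn at each node)
res_dyn_max <- DynForest(timeData = timeData_train, fixedData = fixedData_train,
                         timeVar = "time", idVar = "id", timeVarModel = timeVarModel,
                         ntree = 200, mtry = 7, nodesize = 5, minsplit = 5, cause = 2,
                         Y = Y, seed = 1234)

# Extract and plot the minimal depth from this forest
depth_dyn_max <- var_depth(res_dyn_max)
plot_mindepth(var_depth_obj = depth_dyn_max, plot_level = "feature")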

Guidelines to tune the hyperparameters

The predictive performance of the random forest strongly depends on the hyperparameters mtry, nodesize and minsplit, which should therefore be chosen thoroughly. The nodesize and minsplit hyperparameters control the tree depth, and we want trees deep enough to ensure that the predictions are not biased. By default in the DynForest() function, we fixed \(\texttt{nodesize} = 1\) and \(\texttt{minsplit} = 2\), i.e., the minimum values. However, with a large number of individuals, the tree depth can be slightly decreased to reduce the computation time.

The mtry hyperparameter defines the number of predictors randomly drawn at each node and controls the correlation between the trees. By default, we set mtry to the square root of the number of predictors. However, this hyperparameter should be carefully tuned over the possible values between 1 and the number of predictors, as the predictive performance of the random forest strongly depends on it. In the illustration, we tuned mtry over all possible values (1 to 7).

err.OOB <- vector("numeric", 7)

for (i in 1:7){
  
  set.seed(i)
  
  res_dyn_mtry <- DynForest(timeData = timeData_train, fixedData = fixedData_train,
                            timeVar = "time", idVar = "id", 
                            timeVarModel = timeVarModel, Y = Y,
                            ntree = 200, mtry = i, nodesize = 2, minsplit = 3,
                            cause = 2)
  
  res_dyn_mtry_OOB <- compute_OOBerror(DynForest_obj = res_dyn_mtry)
  
  err.OOB[i] <- mean(res_dyn_mtry_OOB$xerror, na.rm = TRUE)
  
}
library(ggplot2)
ggplot(data.frame(mtry = seq(7), OOB.error = err.OOB), aes(x = mtry, y = OOB.error)) +
  geom_line(color = "red") +
  geom_point(color = "red", size = 1) +
  ylab("OOB error") +
  theme_bw()

The figure displays the evolution of the OOB error according to the mtry hyperparameter, showing large differences in OOB error across mtry values. In particular, we observe the worst predictive performance for the lowest values, then similar results for higher values, with the optimal value (i.e., the lowest OOB error) at \(\texttt{mtry} = 7\). This graph illustrates how crucial it is to carefully tune this hyperparameter.