This package allows you to marginalize arbitrary prediction functions using Monte Carlo integration. Since many prediction functions cannot be easily decomposed into a sum of low-dimensional components, marginalization can be helpful in making these functions interpretable. `marginalPrediction` does this computation and then evaluates the marginalized function at a set of grid points, which can be uniformly created, subsampled from the training data, or explicitly specified via the `points` argument.

The creation of a uniform grid is handled by the `uniformGrid` method. If `uniform = FALSE` and the `points` argument isn't used to specify which points to evaluate, a sample of size `n[1]` is taken from the data without replacement. The function is integrated against a sample of size `n[2]`, also taken from the `data` without replacement. The `int.points` argument can be used to override this (in which case you can specify `n[2] = NA`); `int.points` is a vector of integerish indices which specify the rows of the `data` to use instead.
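To see the grid that would be created for a particular variable, you can call `uniformGrid` directly. A minimal sketch, assuming `uniformGrid` takes the variable and the desired grid size as its two arguments:

library(mmpf)
data(swiss)
## a uniform grid of 10 points spanning the range of Education
uniformGrid(swiss$Education, 10)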
library(mmpf)
library(randomForest)
## randomForest 4.6-12
## Type rfNews() to see new features/changes/bug fixes.
library(ggplot2)
##
## Attaching package: 'ggplot2'
## The following object is masked from 'package:randomForest':
##
## margin
library(reshape2)
data(swiss)
fit = randomForest(Fertility ~ ., swiss)
mp = marginalPrediction(swiss[, -1], "Education", c(10, nrow(swiss)), fit)
mp
## $prediction
## [1] 71.58794 71.67721 71.54158 66.14330 63.66906 63.08498 63.27220
## [8] 63.24072 62.79569 62.79569
##
## $points
## Education
## 1 1
## 2 7
## 3 13
## 4 18
## 5 24
## 6 30
## 7 36
## 8 41
## 9 47
## 10 53
ggplot(data.frame(mp), aes(Education, prediction)) + geom_point() + geom_line()
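The integration sample can also be fixed via `int.points`, as described above. A minimal sketch, reusing `fit`, which integrates against the first 25 rows of `swiss` (since the integration points are given explicitly, `n[2]` is set to `NA`; the indices `1:25` are chosen arbitrarily for illustration):

## marginalize over rows 1:25 instead of a random sample
mp.fixed = marginalPrediction(swiss[, -1], "Education", c(10, NA), fit,
  int.points = 1:25)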
The output of `marginalPrediction` is always a list with two elements, `prediction` and `points`.

By default the Monte Carlo expectation is computed, which is set by the `aggregate.fun` argument's default value, the `mean` function. Substituting, say, the `median` would give a different output.
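For example (a minimal sketch, reusing `fit` from above):

## summarize the integration points with the median rather than the mean
mp.med = marginalPrediction(swiss[, -1], "Education", c(10, nrow(swiss)), fit,
  aggregate.fun = median)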
By passing the `identity` function to `aggregate.fun`, which simply returns its input exactly, the integration points are returned directly, so that the `prediction` element of the return value is a matrix of dimension `n` (that is, `n[1]` rows by `n[2]` columns). Although `n` is an argument, the realized grid size can be larger or smaller depending on the interaction between the input arguments `n` and `data`. For example, if a uniform grid of size 10 is requested (via `n[1]`) from a factor with only 5 levels, a uniform grid of size 5 is created. If `vars` is a vector of length greater than 1, then `n[1]` becomes the size of the Cartesian product of the grids created for each element of `vars`, which can be at most `n[1]^length(vars)`.
mp = marginalPrediction(swiss[, -1], "Education", c(10, 5), fit, aggregate.fun = identity)
mp
## $prediction
## [,1] [,2] [,3] [,4] [,5]
## [1,] 78.25082 73.66435 53.89589 68.25105 62.06470
## [2,] 79.59879 74.33015 53.58569 68.68094 61.67780
## [3,] 78.06703 74.26409 53.97597 69.01337 62.47315
## [4,] 71.08200 68.41759 47.92748 64.27462 57.57195
## [5,] 68.09495 65.46963 46.17074 61.91080 55.91810
## [6,] 67.74647 65.04020 45.06222 61.46220 54.73995
## [7,] 67.97800 65.27173 44.94812 61.73313 54.89971
## [8,] 67.97800 65.27173 44.80190 61.69183 54.76893
## [9,] 67.56380 64.87791 44.16665 61.05306 54.08931
## [10,] 67.56380 64.87791 44.16665 61.05306 54.08931
##
## $points
## Education
## 1 1
## 2 7
## 3 13
## 4 18
## 5 24
## 6 30
## 7 36
## 8 41
## 9 47
## 10 53
ggplot(melt(data.frame(mp), id.vars = "Education"), aes(Education, value, group = variable)) + geom_point() + geom_line()
`predict.fun` specifies a prediction function to apply to the `model` argument. This function must take two arguments, `object` (where `model` is inserted) and `newdata`, a `data.frame` to compute predictions on, which is generated internally and controlled by the other arguments. This allows `marginalPrediction` to handle cases in which predictions for a single data point are vector-valued: that is, classification tasks where probabilities are output, and multivariate regression and/or classification. In these cases `aggregate.fun` is applied separately to each column of the prediction matrix. `aggregate.fun` must take one argument `x`, a vector output from `predict.fun`, and return a vector of dimension no greater than that of `x`.
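Because `aggregate.fun` may return a vector (of dimension no greater than its input), summaries richer than a single number are possible. A minimal sketch, reusing the `swiss` fit from above, that returns pointwise quantiles of the predictions over the integration points:

## 5th, 50th, and 95th percentiles of the predictions at each grid point
mp.q = marginalPrediction(swiss[, -1], "Education", c(10, nrow(swiss)), fit,
  aggregate.fun = function(x) quantile(x, c(.05, .5, .95)))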
data(iris)
fit = randomForest(Species ~ ., iris)
mp = marginalPrediction(iris[, -ncol(iris)], "Petal.Width", c(10, 25), fit,
  predict.fun = function(object, newdata) predict(object, newdata = newdata, type = "prob"))
mp
## $prediction
## setosa versicolor virginica
## [1,] 0.56784 0.29176 0.14040
## [2,] 0.56784 0.29176 0.14040
## [3,] 0.56288 0.29608 0.14104
## [4,] 0.10112 0.64792 0.25096
## [5,] 0.10008 0.64896 0.25096
## [6,] 0.10008 0.63512 0.26480
## [7,] 0.09872 0.49048 0.41080
## [8,] 0.09656 0.19808 0.70536
## [9,] 0.09656 0.19784 0.70560
## [10,] 0.09656 0.19784 0.70560
##
## $points
## Petal.Width
## 1 0.1000000
## 2 0.3666667
## 3 0.6333333
## 4 0.9000000
## 5 1.1666667
## 6 1.4333333
## 7 1.7000000
## 8 1.9666667
## 9 2.2333333
## 10 2.5000000
colnames(mp$prediction) = levels(iris$Species)
names(mp)[1] = ""
plt = melt(data.frame(mp), id.vars = "Petal.Width", variable.name = "class",
  value.name = "prob")
ggplot(plt, aes(Petal.Width, prob, color = class)) + geom_line() + geom_point()
As mentioned before, `vars` can include multiple variables.
mp = marginalPrediction(iris[, -ncol(iris)], c("Petal.Width", "Petal.Length"), c(10, 25), fit,
  predict.fun = function(object, newdata) predict(object, newdata = newdata, type = "prob"))
mp
## $prediction
## setosa versicolor virginica
## [1,] 0.95496 0.04232 0.00272
## [2,] 0.95496 0.04232 0.00272
## [3,] 0.94872 0.04856 0.00272
## [4,] 0.46784 0.51368 0.01848
## [5,] 0.46632 0.51520 0.01848
## [6,] 0.46632 0.51096 0.02272
## [7,] 0.45832 0.37224 0.16944
## [8,] 0.44104 0.23864 0.32032
## [9,] 0.44104 0.23864 0.32032
## [10,] 0.44104 0.23864 0.32032
## [11,] 0.95496 0.04232 0.00272
## [12,] 0.95496 0.04232 0.00272
## [13,] 0.94872 0.04856 0.00272
## [14,] 0.46784 0.51368 0.01848
## [15,] 0.46632 0.51520 0.01848
## [16,] 0.46632 0.51096 0.02272
## [17,] 0.45832 0.37224 0.16944
## [18,] 0.44104 0.23864 0.32032
## [19,] 0.44104 0.23864 0.32032
## [20,] 0.44104 0.23864 0.32032
## [21,] 0.95352 0.04376 0.00272
## [22,] 0.95352 0.04376 0.00272
## [23,] 0.94728 0.05000 0.00272
## [24,] 0.46640 0.51512 0.01848
## [25,] 0.46488 0.51664 0.01848
## [26,] 0.46488 0.51240 0.02272
## [27,] 0.45688 0.37296 0.17016
## [28,] 0.43960 0.23936 0.32104
## [29,] 0.43960 0.23936 0.32104
## [30,] 0.43960 0.23936 0.32104
## [31,] 0.49896 0.48600 0.01504
## [32,] 0.49896 0.48600 0.01504
## [33,] 0.49272 0.49224 0.01504
## [34,] 0.01184 0.95736 0.03080
## [35,] 0.01032 0.95888 0.03080
## [36,] 0.01032 0.95136 0.03832
## [37,] 0.01024 0.66768 0.32208
## [38,] 0.00952 0.41848 0.57200
## [39,] 0.00952 0.41776 0.57272
## [40,] 0.00952 0.41776 0.57272
## [41,] 0.49624 0.48872 0.01504
## [42,] 0.49624 0.48872 0.01504
## [43,] 0.49000 0.49496 0.01504
## [44,] 0.00912 0.96008 0.03080
## [45,] 0.00760 0.96160 0.03080
## [46,] 0.00760 0.95408 0.03832
## [47,] 0.00752 0.67040 0.32208
## [48,] 0.00680 0.42120 0.57200
## [49,] 0.00680 0.42048 0.57272
## [50,] 0.00680 0.42048 0.57272
## [51,] 0.49624 0.48432 0.01944
## [52,] 0.49624 0.48432 0.01944
## [53,] 0.49000 0.49056 0.01944
## [54,] 0.00912 0.95336 0.03752
## [55,] 0.00760 0.95488 0.03752
## [56,] 0.00760 0.94712 0.04528
## [57,] 0.00752 0.66304 0.32944
## [58,] 0.00680 0.41472 0.57848
## [59,] 0.00680 0.41400 0.57920
## [60,] 0.00680 0.41400 0.57920
## [61,] 0.46392 0.41336 0.12272
## [62,] 0.46392 0.41336 0.12272
## [63,] 0.45904 0.41824 0.12272
## [64,] 0.00880 0.77112 0.22008
## [65,] 0.00728 0.77264 0.22008
## [66,] 0.00728 0.76496 0.22776
## [67,] 0.00720 0.57880 0.41400
## [68,] 0.00648 0.12976 0.86376
## [69,] 0.00648 0.12904 0.86448
## [70,] 0.00648 0.12904 0.86448
## [71,] 0.45704 0.16448 0.37848
## [72,] 0.45704 0.16448 0.37848
## [73,] 0.45488 0.16544 0.37968
## [74,] 0.00872 0.30704 0.68424
## [75,] 0.00720 0.30856 0.68424
## [76,] 0.00720 0.27048 0.72232
## [77,] 0.00712 0.28112 0.71176
## [78,] 0.00648 0.04992 0.94360
## [79,] 0.00648 0.04816 0.94536
## [80,] 0.00648 0.04816 0.94536
## [81,] 0.45704 0.16448 0.37848
## [82,] 0.45704 0.16448 0.37848
## [83,] 0.45488 0.16544 0.37968
## [84,] 0.00872 0.30704 0.68424
## [85,] 0.00720 0.30856 0.68424
## [86,] 0.00720 0.27048 0.72232
## [87,] 0.00712 0.28112 0.71176
## [88,] 0.00648 0.04992 0.94360
## [89,] 0.00648 0.04816 0.94536
## [90,] 0.00648 0.04816 0.94536
## [91,] 0.45704 0.16448 0.37848
## [92,] 0.45704 0.16448 0.37848
## [93,] 0.45488 0.16544 0.37968
## [94,] 0.00872 0.30704 0.68424
## [95,] 0.00720 0.30856 0.68424
## [96,] 0.00720 0.27048 0.72232
## [97,] 0.00712 0.28112 0.71176
## [98,] 0.00648 0.04992 0.94360
## [99,] 0.00648 0.04816 0.94536
## [100,] 0.00648 0.04816 0.94536
##
## $points
## Petal.Width Petal.Length
## 1 0.1000000 1.000000
## 2 0.3666667 1.000000
## 3 0.6333333 1.000000
## 4 0.9000000 1.000000
## 5 1.1666667 1.000000
## 6 1.4333333 1.000000
## 7 1.7000000 1.000000
## 8 1.9666667 1.000000
## 9 2.2333333 1.000000
## 10 2.5000000 1.000000
## 11 0.1000000 1.655556
## 12 0.3666667 1.655556
## 13 0.6333333 1.655556
## 14 0.9000000 1.655556
## 15 1.1666667 1.655556
## 16 1.4333333 1.655556
## 17 1.7000000 1.655556
## 18 1.9666667 1.655556
## 19 2.2333333 1.655556
## 20 2.5000000 1.655556
## 21 0.1000000 2.311111
## 22 0.3666667 2.311111
## 23 0.6333333 2.311111
## 24 0.9000000 2.311111
## 25 1.1666667 2.311111
## 26 1.4333333 2.311111
## 27 1.7000000 2.311111
## 28 1.9666667 2.311111
## 29 2.2333333 2.311111
## 30 2.5000000 2.311111
## 31 0.1000000 2.966667
## 32 0.3666667 2.966667
## 33 0.6333333 2.966667
## 34 0.9000000 2.966667
## 35 1.1666667 2.966667
## 36 1.4333333 2.966667
## 37 1.7000000 2.966667
## 38 1.9666667 2.966667
## 39 2.2333333 2.966667
## 40 2.5000000 2.966667
## 41 0.1000000 3.622222
## 42 0.3666667 3.622222
## 43 0.6333333 3.622222
## 44 0.9000000 3.622222
## 45 1.1666667 3.622222
## 46 1.4333333 3.622222
## 47 1.7000000 3.622222
## 48 1.9666667 3.622222
## 49 2.2333333 3.622222
## 50 2.5000000 3.622222
## 51 0.1000000 4.277778
## 52 0.3666667 4.277778
## 53 0.6333333 4.277778
## 54 0.9000000 4.277778
## 55 1.1666667 4.277778
## 56 1.4333333 4.277778
## 57 1.7000000 4.277778
## 58 1.9666667 4.277778
## 59 2.2333333 4.277778
## 60 2.5000000 4.277778
## 61 0.1000000 4.933333
## 62 0.3666667 4.933333
## 63 0.6333333 4.933333
## 64 0.9000000 4.933333
## 65 1.1666667 4.933333
## 66 1.4333333 4.933333
## 67 1.7000000 4.933333
## 68 1.9666667 4.933333
## 69 2.2333333 4.933333
## 70 2.5000000 4.933333
## 71 0.1000000 5.588889
## 72 0.3666667 5.588889
## 73 0.6333333 5.588889
## 74 0.9000000 5.588889
## 75 1.1666667 5.588889
## 76 1.4333333 5.588889
## 77 1.7000000 5.588889
## 78 1.9666667 5.588889
## 79 2.2333333 5.588889
## 80 2.5000000 5.588889
## 81 0.1000000 6.244444
## 82 0.3666667 6.244444
## 83 0.6333333 6.244444
## 84 0.9000000 6.244444
## 85 1.1666667 6.244444
## 86 1.4333333 6.244444
## 87 1.7000000 6.244444
## 88 1.9666667 6.244444
## 89 2.2333333 6.244444
## 90 2.5000000 6.244444
## 91 0.1000000 6.900000
## 92 0.3666667 6.900000
## 93 0.6333333 6.900000
## 94 0.9000000 6.900000
## 95 1.1666667 6.900000
## 96 1.4333333 6.900000
## 97 1.7000000 6.900000
## 98 1.9666667 6.900000
## 99 2.2333333 6.900000
## 100 2.5000000 6.900000
colnames(mp$prediction) = levels(iris$Species)
names(mp) = NULL
plt = melt(data.frame(mp), id.vars = c("Petal.Width", "Petal.Length"),
  variable.name = "class", value.name = "prob")
ggplot(plt, aes(Petal.Width, Petal.Length, fill = prob)) + geom_raster() + facet_wrap(~ class)
Permutation importance is a Monte Carlo method which estimates the importance of variables in determining predictions by computing the change in prediction error that results from repeatedly permuting the values of those variables. `permutationImportance` can compute this type of importance under arbitrary loss functions and contrasts (between the loss on the unpermuted and the permuted data).
permutationImportance(iris, "Sepal.Width", "Species", fit)
## [1] 0.01126667
For methods which generate predictions that are characters or unordered factors, the default loss function is the mean misclassification error; for all other types of predictions, the mean squared error is used.
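Writing these defaults out explicitly makes the interface clearer. A minimal sketch (this assumes, per the above, that the default loss for factor predictions is the mean misclassification error and that the default contrast is a simple difference; results will vary slightly across runs because the permutations are random):

permutationImportance(iris, "Sepal.Width", "Species", fit,
  ## mean misclassification error of the (permuted) predictions x against the target y
  loss.fun = function(x, y) mean(x != y),
  ## difference between the permuted (x) and unpermuted (y) loss
  contrast.fun = function(x, y) x - y)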
The two arguments to `loss.fun` are the permuted predictions and the target variable; in this case they are both vectors of factors. `contrast.fun` takes the output of `loss.fun` on both the permuted and unpermuted predictions (`x` corresponds to the permuted predictions and `y` to the unpermuted predictions). This can, for example, be used to compute the change in the mean misclassification error on a per-class basis.
permutationImportance(iris, "Sepal.Width", "Species", fit,
  loss.fun = function(x, y) {
    ## cross-tabulate permuted predictions (rows) against the target (columns)
    mat = table(x, y)
    ## the number of observations in each class
    n = colSums(mat)
    ## zero out the correct classifications, leaving only the errors
    diag(mat) = 0
    ## for each class, the incorrect predictions of that class, scaled by class size
    rowSums(mat) / n
  },
  contrast.fun = function(x, y) x - y)
## setosa versicolor virginica
## 0.0000 0.0192 0.0126
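To rank all of the features, `permutationImportance` can simply be applied to each one in turn. A minimal sketch using the default loss and contrast (the `features` vector is introduced here purely for illustration):

## importance of every feature in predicting Species
features = setdiff(colnames(iris), "Species")
sapply(features, function(v) permutationImportance(iris, v, "Species", fit))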