
Other types of models

In the following, we show how the counterfactuals workflow carries over to regression models fitted with R packages other than those covered in the introduction vignette.

library("counterfactuals")
library("iml")
library("rpart")

The Predictor class of the iml package provides the flexibility needed to cover classification and regression models fitted with diverse R packages. In the introduction vignette, we saw models fitted with the mlr3 and randomForest packages. In the following, we extend this to a regression tree (rpart) fitted with the caret, tidymodels, mlr (a predecessor of mlr3) and rpart packages. For each model, we generate counterfactuals for the 100th row of the plasma dataset from the gamlss.data package using the NICE method for regression (NICERegr).
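For a model class that iml does not handle out of the box, or whose raw predictions need post-processing, Predictor$new() accepts a predict.function argument. The following is a minimal sketch, not part of the original vignette: the linear model, the wrapper name predict_clipped and the clipping at zero are illustrative assumptions.

```r
library("iml")

data(plasma, package = "gamlss.data")
train = plasma[-100L, ]
x_interest = plasma[100L, ]

# A linear model stands in for "any model with a predict() method"
fit = lm(retplasma ~ ., data = train)

# Hypothetical wrapper: iml calls it with the model and the new data.
# Here we clip predictions at zero, since plasma retinol cannot be negative.
predict_clipped = function(model, newdata) {
  pmax(unname(predict(model, newdata = newdata)), 0)
}

predlm = Predictor$new(model = fit, data = train, y = "retplasma",
  predict.function = predict_clipped)

# Returns a one-row data.frame with the prediction for x_interest
print(predlm$predict(x_interest))
```

The resulting predictor can be passed to NICERegr$new() in the same way as the predictors in the sections that follow.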

data(plasma, package = "gamlss.data")
x_interest = plasma[100L,]
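
The examples in this vignette ask for a predicted retplasma value of at least 500. Before choosing such a range, it can help to look at how the target is distributed in the training data. This is an optional sketch using base R only; the threshold 500 is taken from the calls in this vignette.

```r
data(plasma, package = "gamlss.data")
train = plasma[-100L, ]

# Distribution of the target in the training data
print(summary(train$retplasma))

# Share of training observations that already lie in the desired range
share_in_range = mean(train$retplasma >= 500)
print(share_in_range)
```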

rpart - caret package

library("caret")
treecaret = caret::train(retplasma ~ ., data = plasma[-100L,], method = "rpart", 
  tuneGrid = data.frame(cp = 0.01))
predcaret = Predictor$new(model = treecaret, data = plasma[-100L,], y = "retplasma")
predcaret$predict(x_interest)
#>   .prediction
#> 1    342.9231
nicecaret = NICERegr$new(predcaret, optimization = "proximity", 
  margin_correct = 0.5, return_multiple = FALSE)
nicecaret$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s) 
#>  
#> Desired outcome range: [500, Inf] 
#>  
#> Head: 
#>      age    sex smokstat      bmi vituse calories   fat fiber alcohol cholesterol betadiet retdiet betaplasma
#>    <int> <fctr>   <fctr>    <num> <fctr>    <num> <num> <num>   <num>       <num>    <int>   <int>      <int>
#> 1:    46      1        3 35.25969      3   2667.5 131.6  10.1       0       550.5     1210    1291        218

rpart - tidymodels package

library("tidymodels")
treetm = decision_tree(mode = "regression", engine = "rpart") %>% 
  fit(retplasma ~ ., data = plasma[-100L,])
predtm = Predictor$new(model = treetm, data = plasma[-100L,], y = "retplasma")
predtm$predict(x_interest)
#>      .pred
#> 1 342.9231
nicetm = NICERegr$new(predtm, optimization = "proximity", 
  margin_correct = 0.5, return_multiple = FALSE)
nicetm$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s) 
#>  
#> Desired outcome range: [500, Inf] 
#>  
#> Head: 
#>      age    sex smokstat      bmi vituse calories   fat fiber alcohol cholesterol betadiet retdiet betaplasma
#>    <int> <fctr>   <fctr>    <num> <fctr>    <num> <num> <num>   <num>       <num>    <int>   <int>      <int>
#> 1:    46      1        3 35.25969      3   2667.5 131.6  10.1       0       550.5     1210    1291        218

rpart - mlr package

library("mlr")
task = mlr::makeRegrTask(data = plasma[-100L,], target = "retplasma")
mod = mlr::makeLearner("regr.rpart")

treemlr = mlr::train(mod, task)
predmlr = Predictor$new(model = treemlr, data = plasma[-100L,], y = "retplasma")
predmlr$predict(x_interest)
#>   .prediction
#> 1    342.9231
nicemlr = NICERegr$new(predmlr, optimization = "proximity", 
  margin_correct = 0.5, return_multiple = FALSE)
nicemlr$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s) 
#>  
#> Desired outcome range: [500, Inf] 
#>  
#> Head: 
#>      age    sex smokstat      bmi vituse calories   fat fiber alcohol cholesterol betadiet retdiet betaplasma
#>    <int> <fctr>   <fctr>    <num> <fctr>    <num> <num> <num>   <num>       <num>    <int>   <int>      <int>
#> 1:    46      1        3 35.25969      3   2667.5 131.6  10.1       0       550.5     1210    1291        218

Decision tree - rpart package

treerpart = rpart(retplasma ~ ., data = plasma[-100L,])
predrpart = Predictor$new(model = treerpart, data = plasma[-100L,], y = "retplasma")
predrpart$predict(x_interest)
#>       pred
#> 1 342.9231
nicerpart = NICERegr$new(predrpart, optimization = "proximity", 
  margin_correct = 0.5, return_multiple = FALSE)
nicerpart$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))
#> 1 Counterfactual(s) 
#>  
#> Desired outcome range: [500, Inf] 
#>  
#> Head: 
#>      age    sex smokstat      bmi vituse calories   fat fiber alcohol cholesterol betadiet retdiet betaplasma
#>    <int> <fctr>   <fctr>    <num> <fctr>    <num> <num> <num>   <num>       <num>    <int>   <int>      <int>
#> 1:    46      1        3 35.25969      3   2667.5 131.6  10.1       0       550.5     1210    1291        218
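
The calls above only print the result. As a sketch of how one might inspect it further, the returned Counterfactuals object can be stored and queried. The setup repeats the rpart example; the variable name cfactuals is our own choice, and the methods shown are those we believe the counterfactuals package provides for this object.

```r
library("counterfactuals")
library("iml")
library("rpart")

data(plasma, package = "gamlss.data")
x_interest = plasma[100L, ]

# Same rpart setup as in the example above
treerpart = rpart(retplasma ~ ., data = plasma[-100L, ])
predrpart = Predictor$new(model = treerpart, data = plasma[-100L, ], y = "retplasma")
nicerpart = NICERegr$new(predrpart, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)

# Keep the returned object instead of only printing it
cfactuals = nicerpart$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))

print(cfactuals$data)       # the counterfactuals as a table
print(cfactuals$predict())  # model predictions for the counterfactuals
print(cfactuals$evaluate()) # quality measures for each counterfactual
```

Comparing cfactuals$predict() with the desired outcome range is a quick way to confirm that a returned counterfactual actually reaches the target.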
