
Getting Started with ensembleML

Overview

ensembleML provides a single, consistent API for ensemble machine learning in R. Regardless of which algorithm you choose, the core workflow is always:

em_fit()  ->  em_predict()  ->  em_evaluate()

Advanced usage adds:

em_cv()          # k-fold cross-validation (stability estimates)
em_tune()        # grid-search hyperparameter optimisation
em_compare()     # side-by-side algorithm comparison
em_importance()  # feature importance
em_partial()     # partial dependence plots
em_confusion()   # confusion matrix heatmap
em_calibration() # calibration / reliability diagram
em_residuals()   # regression diagnostics

1. Train a model

data(iris)
set.seed(42)
idx   <- sample(nrow(iris), 120)
train <- iris[idx,  ]
test  <- iris[-idx, ]

rf <- em_fit(Species ~ ., data = train, method = "random_forest",
             verbose = TRUE)
#> [ensembleML] task auto-detected as 'classification'
#> 
#> ╭────────────────────────────────────────────────────╮
#> │  Algorithm:           random_forest               │
#> │  Task:                classification              │
#> │  Response:            Species                     │
#> │  Classes:             setosa, versicolor, virginica│
#> │  Predictors:          4  (Sepal.Length, Sepal.Width, Petal.Length, …)│
#> │  Training n:          120                         │
#> │  Fit time:            0.020 sec                   │
#> │  Train metrics:       accuracy=1.0000  kappa=1.0000  precision=1.0000  recall=1.0000  f1=1.0000  auc=NA│
#> │  ⚠  Use em_evaluate() on held-out data         │
#> ╰────────────────────────────────────────────────────╯

Switching algorithms requires changing a single argument:

xgb <- em_fit(Species ~ ., data = train, method = "xgboost")
ada <- em_fit(Species ~ ., data = train, method = "adaboost")
bag <- em_fit(Species ~ ., data = train, method = "bagging")

2. Predict

preds <- em_predict(rf, newdata = test)
head(preds)
#>      7     11     12     19     23     28 
#> setosa setosa setosa setosa setosa setosa 
#> Levels: setosa versicolor virginica

Class probabilities:

probs <- em_predict(rf, newdata = test, type = "prob")
head(probs, 3)
#>    setosa versicolor virginica
#> 7   1.000      0.000         0
#> 11  0.998      0.002         0
#> 12  1.000      0.000         0
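Class labels and class probabilities are two views of the same prediction. Under the usual convention, the predicted label is the class with the highest probability. The base-R sketch below shows that rule. It assumes the `type = "prob"` output is a matrix or data frame with one column per class, as the printed output suggests. Whether em_predict() breaks ties the same way is also an assumption. `prob_to_class()` is a hypothetical helper, not part of ensembleML.

```r
# Toy probability matrix copied from the printed output above
probs <- matrix(
  c(1.000, 0.000, 0,
    0.998, 0.002, 0,
    1.000, 0.000, 0),
  nrow = 3, byrow = TRUE,
  dimnames = list(c("7", "11", "12"),
                  c("setosa", "versicolor", "virginica"))
)

# Hypothetical helper: pick the highest-probability column in each row.
# ties.method = "first" makes the choice deterministic.
prob_to_class <- function(p) {
  p <- as.matrix(p)
  factor(colnames(p)[max.col(p, ties.method = "first")],
         levels = colnames(p))
}

prob_to_class(probs)
```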

3. Evaluate

em_evaluate(rf, newdata = test)
#>  accuracy     kappa precision    recall        f1       auc 
#>    0.9333    0.8997    0.9364    0.9364    0.9364        NA

Select specific metrics:

em_evaluate(rf, newdata = test, metrics = c("accuracy", "f1", "kappa"))
#> accuracy       f1    kappa 
#>   0.9333   0.9364   0.8997
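For reference, the classification metrics can be computed from a confusion table using their standard definitions. Cohen's kappa, for example, rescales accuracy against the agreement expected by chance. The base-R sketch below assumes macro-averaging (an unweighted mean over classes) for precision, recall and F1. ensembleML's actual averaging scheme is not documented here, and `class_metrics()` is a hypothetical name.

```r
# Hypothetical helper: standard multiclass metrics from true and predicted labels.
# Precision, recall and F1 are macro-averaged (assumption, see text).
class_metrics <- function(truth, pred) {
  lv   <- levels(truth)
  pred <- factor(pred, levels = lv)
  tab  <- table(truth, pred)               # rows = truth, cols = predicted
  n    <- sum(tab)

  accuracy <- sum(diag(tab)) / n
  # Cohen's kappa: observed agreement corrected for chance agreement p_e
  p_e   <- sum(rowSums(tab) * colSums(tab)) / n^2
  kappa <- (accuracy - p_e) / (1 - p_e)

  tp        <- diag(tab)
  precision <- tp / colSums(tab)           # NaN if a class is never predicted
  recall    <- tp / rowSums(tab)
  f1        <- 2 * precision * recall / (precision + recall)

  # Classes with undefined values are skipped by na.rm = TRUE
  c(accuracy  = accuracy,
    kappa     = kappa,
    precision = mean(precision, na.rm = TRUE),
    recall    = mean(recall, na.rm = TRUE),
    f1        = mean(f1, na.rm = TRUE))
}

# Toy example: accuracy = 4/6, kappa = 0.5
truth <- factor(c("a", "a", "b", "b", "c", "c"))
pred  <- factor(c("a", "b", "b", "b", "c", "a"), levels = levels(truth))
round(class_metrics(truth, pred), 4)
```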

4. Cross-validation

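Use em_cv() to get the mean and standard deviation of each metric across folds and repeats before committing to a method. This example cross-validates on the full iris data, including the 30 test rows; pass data = train instead to keep the test set untouched: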

cv_res <- em_cv(Species ~ ., data = iris, method = "random_forest",
                cv_folds = 5, repeats = 3)
cv_res$summary
em_plot_cv(cv_res, metric = "accuracy")
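Conceptually, `cv_folds = 5, repeats = 3` means 15 fit/evaluate rounds. Each repeat reshuffles the rows into 5 disjoint folds, and each fold serves once as the validation set. The summary is then the mean and SD of a metric over all rounds. The base-R sketch below implements that loop. A caller-supplied `score_fn` (a hypothetical name) stands in for fitting and scoring a model. ensembleML's own fold assignment is not shown and may differ; for example, it may stratify by class.

```r
# Assign each of n rows to one of k folds: balanced sizes, random order
make_folds <- function(n, k) {
  sample(rep_len(seq_len(k), n))
}

# Repeated k-fold CV; score_fn(train, valid) must return a single number
repeated_cv <- function(data, score_fn, k = 5, repeats = 3) {
  scores <- numeric(0)
  for (r in seq_len(repeats)) {
    folds <- make_folds(nrow(data), k)     # fresh shuffle per repeat
    for (f in seq_len(k)) {
      valid  <- data[folds == f, , drop = FALSE]
      train  <- data[folds != f, , drop = FALSE]
      scores <- c(scores, score_fn(train, valid))
    }
  }
  c(mean = mean(scores), sd = sd(scores), n = length(scores))
}

# Usage: a deliberately trivial "model" that always predicts the majority class
majority_acc <- function(train, valid) {
  top <- names(which.max(table(train$Species)))
  mean(valid$Species == top)
}
set.seed(1)
repeated_cv(iris, majority_acc, k = 5, repeats = 3)
```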

5. Tune hyperparameters

grid <- list(ntree = c(100, 300, 500), mtry = c(1, 2, 3))

tuned <- em_tune(
  Species ~ ., data = train, method = "random_forest",
  param_grid = grid, cv_folds = 5
)

tuned$best_params
tuned$best_score
tuned$results
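A grid search evaluates every combination of the supplied values. The grid above therefore defines 3 × 3 = 9 candidate settings, each scored by 5-fold CV. The base-R sketch below enumerates such a grid with `expand.grid()` and picks the best row from a vector of scores. The scores are made-up placeholders, not output from em_tune(). The "higher is better" rule is an assumption that suits accuracy but not error metrics. `best_params()` is a hypothetical helper.

```r
grid <- list(ntree = c(100, 300, 500), mtry = c(1, 2, 3))

# Cartesian product of all parameter values: one row per candidate
candidates <- expand.grid(grid, KEEP.OUT.ATTRS = FALSE)
nrow(candidates)   # 9

# Hypothetical helper: best candidate given one CV score per row
# (higher is better; negate scores for error metrics such as RMSE)
best_params <- function(candidates, scores) {
  stopifnot(length(scores) == nrow(candidates))
  candidates[which.max(scores), , drop = FALSE]
}

# Placeholder scores for illustration only, NOT real tuning results
scores <- c(0.93, 0.95, 0.94, 0.94, 0.96, 0.95, 0.94, 0.95, 0.95)
best_params(candidates, scores)
```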

6. Compare algorithms

cmp <- em_compare(Species ~ ., train = train, test = test)
cmp$table

7. Feature importance

em_importance(rf, top_n = 4)
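The importance measure em_importance() uses is not stated here; random forests commonly report mean decrease in accuracy or in Gini impurity. Permutation importance is one model-agnostic alternative. It shuffles one feature, re-predicts, and records how much the score drops. The base-R sketch below uses `lm()` on `mtcars` purely for illustration. `permutation_importance()`, `predict_fn` and `metric_fn` are hypothetical names, not ensembleML API.

```r
# Hypothetical helper: permutation importance = mean score drop when a feature is shuffled
permutation_importance <- function(predict_fn, data, response, metric_fn,
                                   n_rep = 5) {
  baseline <- metric_fn(data[[response]], predict_fn(data))
  features <- setdiff(names(data), response)
  imp <- sapply(features, function(f) {
    drops <- replicate(n_rep, {
      shuffled <- data
      shuffled[[f]] <- sample(shuffled[[f]])   # break the feature-response link
      baseline - metric_fn(data[[response]], predict_fn(shuffled))
    })
    mean(drops)
  })
  sort(imp, decreasing = TRUE)
}

# Usage with a linear model; score = negative RMSE, so higher is better
neg_rmse <- function(y, yhat) -sqrt(mean((y - yhat)^2))
fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)
set.seed(3)
permutation_importance(function(d) predict(fit, newdata = d),
                       mtcars[, c("mpg", "wt", "hp", "qsec")],
                       response = "mpg", metric_fn = neg_rmse)
```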


8. Partial dependence

em_partial(rf, data = train, feature = "Petal.Length")
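Partial dependence for a feature is computed in three steps. Set that feature to a fixed value v in every row of the data, average the model's predictions, and repeat across a grid of v values. The base-R sketch below shows that computation; it is not ensembleML's code. `lm()` on `mtcars` stands in for a fitted ensemble, and `partial_dependence()` is a hypothetical name. For a classifier you would average a class probability instead.

```r
# Hypothetical helper: partial dependence of predictions on one feature
partial_dependence <- function(predict_fn, data, feature, grid_size = 10) {
  x    <- data[[feature]]
  grid <- seq(min(x), max(x), length.out = grid_size)
  pd <- vapply(grid, function(v) {
    d <- data
    d[[feature]] <- v              # set the feature to v for every row
    mean(predict_fn(d))            # average prediction = PD at v
  }, numeric(1))
  data.frame(value = grid, pd = pd)
}

fit <- lm(mpg ~ wt + hp, data = mtcars)
pdp <- partial_dependence(function(d) predict(fit, newdata = d),
                          mtcars, feature = "wt")
head(pdp)
```

For a linear model the curve is a straight line whose slope equals the fitted coefficient. Tree ensembles typically produce step-shaped curves.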

9. Confusion matrix

em_confusion(rf, newdata = test)
em_confusion(rf, newdata = test, normalise = TRUE)
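A confusion matrix is a cross-tabulation of true against predicted classes. The base-R sketch below builds one with `table()` and normalises each row to proportions with `prop.table()`. That makes the diagonal equal to per-class recall. Whether `normalise = TRUE` divides by row (true-class) totals, rather than by column or grand totals, is an assumption; check `?em_confusion`.

```r
truth <- factor(c("a", "a", "a", "b", "b", "c"), levels = c("a", "b", "c"))
pred  <- factor(c("a", "a", "b", "b", "b", "c"), levels = c("a", "b", "c"))

# Counts: rows = true class, columns = predicted class
cm <- table(truth = truth, predicted = pred)
cm

# Row-normalised: each row sums to 1, so the diagonal is per-class recall
cm_norm <- prop.table(cm, margin = 1)
round(cm_norm, 2)
```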

10. Regression example

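Everything works identically for numeric responses; the task is auto-detected as regression. In this simulated data, y does not depend on x1 or x2: both random terms in y are fresh rnorm() draws. The model therefore has no signal to learn, and a negative rsq on the test set is expected: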

set.seed(7)
reg_data  <- data.frame(
  x1 = rnorm(200), x2 = rnorm(200),
  y  = 3 + 2 * rnorm(200) + rnorm(200))
reg_train <- reg_data[1:160, ]
reg_test  <- reg_data[161:200, ]

reg_model <- em_fit(y ~ ., data = reg_train, method = "random_forest")
#> [ensembleML] task auto-detected as 'regression'
em_evaluate(reg_model, reg_test)
#>    rmse     mae    mape     rsq adj_rsq 
#>  2.4320  1.8556 88.1007 -0.2193 -0.2852
em_residuals(reg_model, reg_test)
#> `geom_smooth()` using formula = 'y ~ x'
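The regression metrics follow standard definitions. The printed values are consistent with rsq = 1 - SSE/SST, which, unlike a squared correlation, goes negative when the model does worse than predicting the mean. They also fit adj_rsq = 1 - (1 - rsq)(n - 1)/(n - p - 1) with n = 40 test rows and p = 2 predictors: 1 - 1.2193 × 39/37 ≈ -0.2852. The base-R sketch below implements those formulas and assumes MAPE is reported in percent. It is a reconstruction, not ensembleML's source, and `regression_metrics()` is a hypothetical name.

```r
# Hypothetical helper: standard regression metrics; p = number of predictors
regression_metrics <- function(y, yhat, p) {
  n   <- length(y)
  res <- y - yhat
  sse <- sum(res^2)
  sst <- sum((y - mean(y))^2)
  rsq <- 1 - sse / sst                   # negative if worse than predicting mean(y)
  c(rmse    = sqrt(mean(res^2)),
    mae     = mean(abs(res)),
    mape    = 100 * mean(abs(res / y)),  # undefined if any y == 0
    rsq     = rsq,
    adj_rsq = 1 - (1 - rsq) * (n - 1) / (n - p - 1))
}

# Toy values: rmse = 1, mae = 1, rsq = 0.8, adj_rsq = 0.7
y <- c(2, 4, 6, 8)
regression_metrics(y, yhat = c(3, 3, 7, 7), p = 1)
```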


Citation

If you use ensembleML in published work, please cite it:

citation("ensembleML")

The individual algorithms should also be cited; citation("ensembleML") lists the full set of references.


Session info

sessionInfo()
#> R version 4.2.1 (2022-06-23 ucrt)
#> Platform: x86_64-w64-mingw32/x64 (64-bit)
#> Running under: Windows 10 x64 (build 26200)
#> 
#> Matrix products: default
#> 
#> locale:
#> [1] LC_COLLATE=C                          
#> [2] LC_CTYPE=English_United States.utf8   
#> [3] LC_MONETARY=English_United States.utf8
#> [4] LC_NUMERIC=C                          
#> [5] LC_TIME=English_United States.utf8    
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] ensembleML_0.2.5
#> 
#> loaded via a namespace (and not attached):
#>  [1] bslib_0.10.0         compiler_4.2.1       pillar_1.11.1       
#>  [4] RColorBrewer_1.1-3   jquerylib_0.1.4      tools_4.2.1         
#>  [7] digest_0.6.39        lattice_0.20-45      nlme_3.1-168        
#> [10] jsonlite_2.0.0       evaluate_1.0.5       lifecycle_1.0.5     
#> [13] tibble_3.3.1         gtable_0.3.6         mgcv_1.8-40         
#> [16] pkgconfig_2.0.3      rlang_1.1.7          Matrix_1.6-5        
#> [19] cli_3.6.5            rstudioapi_0.18.0    yaml_2.3.12         
#> [22] xfun_0.57            fastmap_1.2.0        gridExtra_2.3       
#> [25] withr_3.0.2          dplyr_1.2.0          knitr_1.51          
#> [28] generics_0.1.4       sass_0.4.10          vctrs_0.7.2         
#> [31] grid_4.2.1           tidyselect_1.2.1     glue_1.7.0          
#> [34] R6_2.6.1             otel_0.2.0           rmarkdown_2.31      
#> [37] ggplot2_4.0.2        farver_2.1.2         magrittr_2.0.3      
#> [40] splines_4.2.1        scales_1.4.0         htmltools_0.5.9     
#> [43] randomForest_4.7-1.2 labeling_0.4.3       S7_0.2.1            
#> [46] cachem_1.1.0
