This demonstration of the caret package was given by Mark Lawson, a bioinformatician at HemoShear LLC in Charlottesville, VA. The caret package (short for Classification And REgression Training) is a set of functions that streamline the process of creating predictive models. The package contains tools for data splitting, pre-processing, feature selection, model tuning using resampling, and variable importance estimation. Read more about the caret package at https://topepo.github.io/caret/.
This demonstration uses the caret package to split data into training and testing sets, and run repeated cross-validation to train random forest and penalized logistic regression models for classifying Fisher’s iris data.
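The code below assumes caret and the model back ends it calls are already installed; if not, a one-time setup along these lines should suffice (randomForest and stepPlr provide the "rf" and "plr" methods used later):

install.packages(c("caret", "randomForest", "stepPlr"))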
library(caret)
## Loading required package: lattice
## Loading required package: ggplot2
set.seed(42)
# The iris dataset
data(iris)
head(iris)
##   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
## 1          5.1         3.5          1.4         0.2  setosa
## 2          4.9         3.0          1.4         0.2  setosa
## 3          4.7         3.2          1.3         0.2  setosa
## 4          4.6         3.1          1.5         0.2  setosa
## 5          5.0         3.6          1.4         0.2  setosa
## 6          5.4         3.9          1.7         0.4  setosa
summary(iris)
##   Sepal.Length    Sepal.Width     Petal.Length    Petal.Width
##  Min.   :4.300   Min.   :2.000   Min.   :1.000   Min.   :0.100
##  1st Qu.:5.100   1st Qu.:2.800   1st Qu.:1.600   1st Qu.:0.300
##  Median :5.800   Median :3.000   Median :4.350   Median :1.300
##  Mean   :5.843   Mean   :3.057   Mean   :3.758   Mean   :1.199
##  3rd Qu.:6.400   3rd Qu.:3.300   3rd Qu.:5.100   3rd Qu.:1.800
##  Max.   :7.900   Max.   :4.400   Max.   :6.900   Max.   :2.500
##        Species
##  setosa    :50
##  versicolor:50
##  virginica :50
##
##
##
# look at the data
featurePlot(x = iris[, 1:4],
            y = iris$Species,
            plot = "pairs",
            # add a key at the top
            auto.key = list(columns = 3))
# separate training and test sets
trainIndex <- createDataPartition(iris$Species, # data labels
                                  p = .7,       # proportion used for training
                                  list = FALSE, # return a matrix instead of a list
                                  times = 1)    # number of partitions to create
head(trainIndex)
##      Resample1
## [1,]         1
## [2,]         3
## [3,]         4
## [4,]         5
## [5,]         6
## [6,]         7
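# Aside (not in the original demo): with times > 1, createDataPartition
# returns one column per partition, e.g. a 105 x 3 index matrix here.
# (Left commented out so the RNG state, and hence the results below, match.)
# multiIndex <- createDataPartition(iris$Species, p = .7,
#                                   list = FALSE, times = 3)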
# training data
train_data <- iris[trainIndex, 1:4]
train_labels <- iris[trainIndex, 5]
# test data
test_data <- iris[-trainIndex, 1:4]
test_labels <- iris[-trainIndex, 5]
table(train_labels)
## train_labels
##     setosa versicolor  virginica
##         35         35         35
table(test_labels)
## test_labels
##     setosa versicolor  virginica
##         15         15         15
# pre-process the data
preprocess_methods <- c("center", "scale")
# estimate transformation parameters from the training data only
preprocess <- preProcess(train_data,
                         method = preprocess_methods)
help(preProcess)
# apply transformation values
train_data.pre <- predict(preprocess, train_data)
test_data.pre <- predict(preprocess, test_data)
summary(train_data)
##   Sepal.Length    Sepal.Width    Petal.Length    Petal.Width
##  Min.   :4.300   Min.   :2.20   Min.   :1.100   Min.   :0.100
##  1st Qu.:5.200   1st Qu.:2.80   1st Qu.:1.600   1st Qu.:0.300
##  Median :5.700   Median :3.00   Median :4.200   Median :1.300
##  Mean   :5.821   Mean   :3.07   Mean   :3.748   Mean   :1.202
##  3rd Qu.:6.400   3rd Qu.:3.30   3rd Qu.:5.100   3rd Qu.:1.800
##  Max.   :7.700   Max.   :4.40   Max.   :6.900   Max.   :2.500
summary(train_data.pre)
##   Sepal.Length      Sepal.Width       Petal.Length      Petal.Width
##  Min.   :-1.9887   Min.   :-2.0331   Min.   :-1.5096   Min.   :-1.4469
##  1st Qu.:-0.8119   1st Qu.:-0.6302   1st Qu.:-1.2245   1st Qu.:-1.1843
##  Median :-0.1581   Median :-0.1626   Median : 0.2579   Median : 0.1288
##  Mean   : 0.0000   Mean   : 0.0000   Mean   : 0.0000   Mean   : 0.0000
##  3rd Qu.: 0.7571   3rd Qu.: 0.5389   3rd Qu.: 0.7711   3rd Qu.: 0.7853
##  Max.   : 2.4569   Max.   : 3.1109   Max.   : 1.7974   Max.   : 1.7045
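# Sanity check (added): preProcess with center/scale should match base R's
# scale() using means and sds computed on the training data only
manual <- scale(test_data,
                center = colMeans(train_data),
                scale  = sapply(train_data, sd))
all.equal(as.matrix(test_data.pre), manual, check.attributes = FALSE) # TRUE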
featurePlot(x = train_data.pre,
            y = train_labels,
            plot = "pairs",
            # add a key at the top
            auto.key = list(columns = 3))
# train the model
# resampling method / counts
train_control <- trainControl(method = "repeatedcv", # type of resampling
                              number = 10,           # number of folds
                              repeats = 2)           # repeats of the whole CV process
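# The same trainControl interface covers other resampling schemes
# (added for reference):
# trainControl(method = "boot", number = 25)  # 25 bootstrap resamples
# trainControl(method = "LOOCV")              # leave-one-out CV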
# train a random forest
train_model.rf <- train(x = train_data.pre,        # data
                        y = train_labels,          # labels
                        method = "rf",             # classification method
                        trControl = train_control, # resampling strategy
                        metric = "Accuracy",       # metric used to pick the best model
                        tuneLength = 3)            # number of values to try per tuning parameter
train_model.rf
## Random Forest
##
## 105 samples
## 4 predictor
## 3 classes: 'setosa', 'versicolor', 'virginica'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 2 times)
## Summary of sample sizes: 94, 94, 94, 95, 94, 95, ...
## Resampling results across tuning parameters:
##
##   mtry  Accuracy   Kappa
##   2     0.9510859  0.9262311
##   3     0.9465404  0.9196066
##   4     0.9465404  0.9196066
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.
names(train_model.rf)
## [1] "method" "modelInfo" "modelType" "results"
## [5] "pred" "bestTune" "call" "dots"
## [9] "metric" "control" "finalModel" "preProcess"
## [13] "trainingData" "resample" "resampledCM" "perfNames"
## [17] "maximize" "yLimits" "times" "levels"
# train a penalized logistic regression
train_model.plr <- train(x = train_data.pre,
                         y = train_labels,
                         method = "plr",
                         trControl = train_control,
                         metric = "Accuracy",
                         tuneLength = 3)
train_model.plr
## Penalized Logistic Regression
##
## 105 samples
## 4 predictor
## 3 classes: 'setosa', 'versicolor', 'virginica'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 2 times)
## Summary of sample sizes: 94, 94, 95, 95, 95, 94, ...
## Resampling results across tuning parameters:
##
##   lambda  Accuracy   Kappa
##   0e+00   0.6660606  0.4997727
##   1e-04   0.6660606  0.4997727
##   1e-01   0.6660606  0.4997727
##
## Tuning parameter 'cp' was held constant at a value of bic
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were lambda = 0.1 and cp = bic.
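# Note (added): plr fares far worse than the random forest here. stepPlr's
# plr fits a two-class logistic model, so the three-class iris outcome is
# likely an awkward fit for it. A multinomial model (method "multinom",
# from the nnet package) handles all three classes natively; commented out
# to keep the RNG state, and the printed results below, unchanged:
# train(x = train_data.pre, y = train_labels, method = "multinom",
#       trControl = train_control, metric = "Accuracy", tuneLength = 3,
#       trace = FALSE) # trace = FALSE silences nnet's fitting messages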
# results of the "best" model
train_model.rf$finalModel
##
## Call:
## randomForest(x = x, y = y, mtry = param$mtry)
## Type of random forest: classification
## Number of trees: 500
## No. of variables tried at each split: 2
##
## OOB estimate of error rate: 4.76%
## Confusion matrix:
##            setosa versicolor virginica class.error
## setosa         35          0         0  0.00000000
## versicolor      0         33         2  0.05714286
## virginica       0          3        32  0.08571429
train_model.plr$finalModel
##
## Call:
## stepPlr::plr(x = x, y = y, weights = if (!is.null(wts)) wts else rep(1,
##     length(y)), lambda = param$lambda, cp = as.character(param$cp))
##
## Coefficients:
##    Intercept  Sepal.Length   Sepal.Width  Petal.Length   Petal.Width
##     -3.24506      -1.32064       1.45670      -2.41441      -2.13541
##
## Null deviance: 133.67 on 104 degrees of freedom
## Residual deviance: 1.42 on 102.18 degrees of freedom
## Score: deviance + 4.7 * df = 14.53
# dive into the final model
class(train_model.rf$finalModel)
## [1] "randomForest"
names(train_model.rf$finalModel)
## [1] "call" "type" "predicted"
## [4] "err.rate" "confusion" "votes"
## [7] "oob.times" "classes" "importance"
## [10] "importanceSD" "localImportance" "proximity"
## [13] "ntree" "mtry" "forest"
## [16] "y" "test" "inbag"
## [19] "xNames" "problemType" "tuneValue"
## [22] "obsLevels" "param"
plot(train_model.rf$finalModel)
class(train_model.plr$finalModel)
## [1] "plr"
names(train_model.plr$finalModel)
## [1] "coefficients" "covariance" "deviance"
## [4] "null.deviance" "df" "score"
## [7] "nobs" "cp" "fitted.values"
## [10] "linear.predictors" "level" "call"
## [13] "xNames" "problemType" "tuneValue"
## [16] "obsLevels" "param"
# apply to test data
results.rf <- predict(train_model.rf, test_data.pre)
results.plr <- predict(train_model.plr, test_data.pre)
# side-by-side results
View(data.frame(test_labels,
                results.rf,
                results.plr))
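# View() only works in an interactive session; in a script, print the
# first few rows instead
head(data.frame(test_labels, results.rf, results.plr))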
# stats for the results
confusionMatrix(data = results.rf,
                reference = test_labels)
## Confusion Matrix and Statistics
##
##             Reference
## Prediction   setosa versicolor virginica
##   setosa         15          0         0
##   versicolor      0         13         1
##   virginica       0          2        14
##
## Overall Statistics
##
## Accuracy : 0.9333
## 95% CI : (0.8173, 0.986)
## No Information Rate : 0.3333
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.9
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 1.0000            0.8667           0.9333
## Specificity                 1.0000            0.9667           0.9333
## Pos Pred Value              1.0000            0.9286           0.8750
## Neg Pred Value              1.0000            0.9355           0.9655
## Prevalence                  0.3333            0.3333           0.3333
## Detection Rate              0.3333            0.2889           0.3111
## Detection Prevalence        0.3333            0.3111           0.3556
## Balanced Accuracy           1.0000            0.9167           0.9333
confusionMatrix(data = results.plr,
                reference = test_labels)
## Confusion Matrix and Statistics
##
##             Reference
## Prediction   setosa versicolor virginica
##   setosa         15          0         0
##   versicolor      0         15        15
##   virginica       0          0         0
##
## Overall Statistics
##
## Accuracy : 0.6667
## 95% CI : (0.5105, 0.8)
## No Information Rate : 0.3333
## P-Value [Acc > NIR] : 5.001e-06
##
## Kappa : 0.5
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
##                      Class: setosa Class: versicolor Class: virginica
## Sensitivity                 1.0000            1.0000           0.0000
## Specificity                 1.0000            0.5000           1.0000
## Pos Pred Value              1.0000            0.5000              NaN
## Neg Pred Value              1.0000            1.0000           0.6667
## Prevalence                  0.3333            0.3333           0.3333
## Detection Rate              0.3333            0.3333           0.0000
## Detection Prevalence        0.3333            0.6667           0.0000
## Balanced Accuracy           1.0000            0.7500           0.5000
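# The pieces of a confusionMatrix object can also be pulled out
# programmatically (added example):
cm.rf <- confusionMatrix(data = results.rf, reference = test_labels)
cm.rf$overall["Accuracy"]      # overall accuracy
cm.rf$byClass[, "Sensitivity"] # per-class sensitivity
cm.rf$table                    # the raw confusion matrix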
# variable importance
varImp(train_model.rf)
## rf variable importance
##
##              Overall
## Petal.Width   100.00
## Petal.Length   89.18
## Sepal.Length   13.90
## Sepal.Width     0.00
train_model.rf$finalModel$importance
##              MeanDecreaseGini
## Sepal.Length         6.457729
## Sepal.Width          2.358824
## Petal.Length        28.651485
## Petal.Width         31.840657
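# varImp objects also have a plot method for a quick visual comparison
plot(varImp(train_model.rf))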