Validación cruzada de K-Fold | Guía para la validación cruzada de K-Fold en R

Contenidos

Requisitos previos: lenguaje de programación R básico y conocimiento básico de clasificación

Mientras que el enfoque del conjunto de validación funciona dividiendo el conjunto de datos una vez, k-Fold lo hace cinco o diez veces. Imagine que está haciendo el enfoque del conjunto de validación diez veces utilizando un grupo de datos diferente.

Digamos que tenemos 100 filas de datos. Los dividimos aleatoriamente en diez grupos de pliegues. Cada pliegue constará de alrededor de 10 filas de datos. El primer pliegue se utilizará como conjunto de validación y el resto es para el conjunto de entrenamiento. Luego entrenamos nuestro modelo usando este conjunto de datos y calculamos la precisión o pérdida. Luego repetimos este proceso pero usando un pliegue diferente para el conjunto de validación. Vea la imagen a continuación.

70068k-fold20cv-9297840

Validación cruzada de K-Fold. Imagen del autor

Saltemos al código

Las bibliotecas que usamos son estas dos:

library(tidyverse) 
library(caret)

Los datos utilizados aquí son datos de enfermedades cardíacas de la UCI que se pueden descargar en Kaggle. También puede utilizar cualquier dato de clasificación para este experimento.

data <- read.csv("../input/heart-disease-uci/heart.csv")
head(data)

Aquí están las seis filas superiores de los datos cargados. Tiene trece predictores y la última columna es la variable de respuesta. También puede verificar las últimas filas usando la función tail ().

55736screen20shot202021-03-1120at2018-53-15-6978641

La distribución de datos

Aquí queremos confirmar que la distribución entre los datos de dos etiquetas no es muy diferente. Porque los conjuntos de datos desequilibrados pueden conducir a una precisión desequilibrada. Esto significa que su modelo siempre predecirá hacia una sola etiqueta, o siempre predecirá 0 o 1.

hist(data$target,col="coral")
prop.table(table(data$target))
72182screen20shot202021-03-1420at2014-59-04-5300420

Este gráfico muestra que nuestro conjunto de datos está ligeramente desequilibrado pero aún lo suficientemente bueno. Tiene una proporción de 46:54. Debería empezar a preocuparse si su conjunto de datos tiene más del 60% de los datos en una clase. En ese caso, puede usar SMOTE para manejar un conjunto de datos desequilibrado.

El k-Fold

set.seed(100)
trctrl <- trainControl(method = "cv", number = 10, savePredictions=TRUE)
nb_fit <- train(factor(target) ~., data = data, method = "naive_bayes", trControl=trctrl, tuneLength = 0)
nb_fit

La primera línea es establecer la semilla del pseudoaleatorio para que se pueda reproducir el mismo resultado. Puede utilizar cualquier número para el valor inicial.

A continuación, podemos establecer la configuración de k-Fold en la función trainControl (). Establezca el parámetro del método en “cv” y el parámetro numérico en 10. Significa que establecemos la validación cruzada con diez pliegues. Podemos establecer el número de pliegue con cualquier número, pero la forma más común es establecerlo en cinco o diez.

La función train () se usa para determinar el método que usamos. Aquí usamos el método Naive Bayes y establecemos tuneLength en cero porque nos enfocamos en evaluar el método en cada pliegue. También podemos establecer tuneLength si queremos hacer el ajuste de parámetros durante la validación cruzada. Por ejemplo, si usamos el método K-NN y queremos analizar cuántos K son los mejores para nuestro modelo.

Puede ver el método admitido en Documentación R.

Tenga en cuenta que la validación cruzada de k-Fold puede tardar un poco porque ejecuta el proceso de formación diez veces.

96911screen20shot202021-03-1220at2022-04-09-5104282

Imprimirá el detalle en la consola una vez que esté terminado. La precisión que se muestra en la consola es la precisión promedio de todos los pliegues de entrenamiento. Podemos ver que nuestro modelo tiene una precisión promedio del 83%.

Despliegue el k-Fold

Podemos determinar que nuestro modelo está funcionando bien en cada pliegue si observamos la precisión de cada pliegue. Para hacer esto, asegúrese de configurar el savePredictions parámetro a TRUE en la función trainControl ().

pred <- nb_fit$pred
pred$equal <- ifelse(pred$pred == pred$obs, 1,0)
eachfold <- pred %>%                                        
  group_by(Resample) %>%                         
  summarise_at(vars(equal),                     
               list(Accuracy = mean))              
eachfold

Aquí está la tabla de precisión en cada pliegue.

43261screen20shot202021-03-1220at2022-04-18-3228982

También podemos trazarlo en el gráfico para que sea más fácil de analizar. En este caso, usamos la gráfica de caja para representar nuestras precisiones.

ggplot(data=eachfold, aes(x=Resample, y=Accuracy, group=1)) +
geom_boxplot(color="maroon") +
geom_point() +
theme_minimal()

84001screen20shot202021-03-1420at2015-27-02-3273577

Podemos ver que cada uno de los pliegues logra una precisión que no se diferencia mucho entre sí. La precisión más baja es 72,58%, y también en el diagrama de caja, no vemos valores atípicos. Lo que significa que nuestro modelo estaba funcionando bien en la validación cruzada de k veces.

Que sigue

  • Prueba con un número diferente de pliegues
  • Hacer un ajuste de parámetros
  • Utilice otros conjuntos de datos y métodos

Breve biografía del autor

Me llamo Muhammad Arnaldo, un entusiasta del aprendizaje automático y la ciencia de datos. Actualmente estudiante de maestría en ciencias de la computación en Indonesia.

Los medios que se muestran en este artículo no son propiedad de DataPeaker y se utilizan a discreción del autor.

Suscribite a nuestro Newsletter

No te enviaremos correo SPAM. Lo odiamos tanto como tú.