Kreuzvalidierung von K-Fold | Leitfaden zur Kreuzvalidierung von K-Fold in R

Inhalt

Bisherige Anforderungen: Grundlegende Programmiersprache R und Grundkenntnisse der Klassifikation

Während der Validierungssatzansatz funktioniert, indem der Datensatz einmal geteilt wird, k-Fold macht es fünf oder zehn Mal. Stellen Sie sich vor, Sie führen den Validierungssatzansatz zehnmal mit einem anderen Datensatz durch.

Sagen wir, wir haben 100 Datenzeilen. Wir teilen sie nach dem Zufallsprinzip in zehn Gruppen von Falten ein. Jede Falte besteht aus ca. 10 Datenzeilen. Die erste Faltung wird als Validierungssatz verwendet und der Rest ist für den Trainingssatz. Anschließend trainieren wir unser Modell mit diesem Datensatz und berechnen die Präzision oder den Verlust. Dann wiederholen wir diesen Vorgang, verwenden jedoch eine andere Faltung für den Validierungssatz. Siehe das Bild unten.

70068k-fold20cv-9297840

Kreuzvalidierung von K-Fold. Bild des Autors

Kommen wir zum Code

Die von uns verwendeten Bibliotheken sind diese beiden:

Bücherei(ordentlichversum) 
Bücherei(Caret)

Die hier verwendeten Daten sind Herzkrankheitsdaten der Intensivstation, die heruntergeladen werden können unter Kaggle. Sie können für dieses Experiment auch beliebige Klassifizierungsdaten verwenden.

Daten <- lesen.csv("../input/herzkrankheit-uci/heart.csv")
Kopf(Daten)

Hier sind die oberen sechs Zeilen der geladenen Daten. Es hat dreizehn Prädiktoren und die letzte Spalte ist die Antwortvariable. Sie können die letzten Zeilen auch mit der Schwanzfunktion überprüfen ().

55736screen20shot202021-03-1120at2018-53-15-6978641

Datenverteilung

Hier wollen wir bestätigen, dass die Verteilung zwischen den Daten zweier Labels nicht sehr unterschiedlich ist. Weil unausgeglichene Datensätze zu unausgeglichener Genauigkeit führen können. Das bedeutet, dass Ihr Modell immer auf ein einzelnes Label prognostiziert., oder wird immer vorhersagen 0 Ö 1.

hist(Daten$Ziel,col="Koralle")
prop.tabelle(Tisch(Daten$Ziel))
72182screen20shot202021-03-1420at2014-59-04-5300420

Diese Grafik zeigt, dass unser Datensatz etwas unausgeglichen, aber immer noch gut genug ist. Es hat ein Verhältnis von 46:54. Sie sollten sich Sorgen machen, wenn Ihr Datensatz mehr als 60% der Daten in einer Klasse. Dann, Sie können SMOTE verwenden, um einen unausgeglichenen Datensatz zu verarbeiten.

Die k-Falte

set.seed(100)
trctrl <- trainControl(Methode = "Lebenslauf", Zahl = 10, savePredictions=TRUE)
nb_fit <- Bahn(Faktor(Ziel) ~., Daten = Daten, Methode = "naiv_bayes", trControl=trctrl, tuneLänge = 0)
nb_fit

Die erste Zeile besteht darin, den Samen des Pseudozufalls so zu setzen, dass das gleiche Ergebnis reproduziert werden kann. Sie können eine beliebige Zahl für den Anfangswert verwenden.

Dann, wir können die k-Fold-Einstellung in der trainControl-Funktion einstellen (). Stellen Sie den Methodenparameter auf „cv“ und den numerischen Parameter auf 10. Das bedeutet, dass wir die Kreuzvalidierung mit zehn Falten festlegen. Wir können die Falznummer mit einer beliebigen Zahl festlegen, aber die gebräuchlichste Methode ist die Einstellung auf fünf oder zehn.

Die Zugfunktion () wird verwendet, um die von uns verwendete Methode zu bestimmen. Hier verwenden wir die Naive Bayes-Methode und setzen tuneLength auf Null, da wir uns darauf konzentrieren, die Methode für jede Falte zu bewerten. Wir können auch tuneLength einstellen, wenn wir Parametereinstellungen während der Kreuzvalidierung vornehmen möchten. Zum Beispiel, wenn wir die K-NN-Methode verwenden und analysieren möchten, wie viele K für unser Modell am besten sind.

Sie können die unterstützte Methode in . sehen Dokumentation R.

Bitte beachten Sie, dass die Kreuzvalidierung von k-Fold eine Weile dauern kann, da Sie den Trainingsprozess zehn Mal durchlaufen.

96911screen20shot202021-03-1220at2022-04-09-5104282

Es druckt die Details an die Konsole, sobald es fertig ist. Die auf der Konsole angezeigte Genauigkeit ist die durchschnittliche Genauigkeit aller Trainingsfalten. Wir können sehen, dass unser Modell eine durchschnittliche Genauigkeit von hat 83%.

Entfalte die K-Falte

Wir können feststellen, dass unser Modell in jeder Falte gut abschneidet, indem wir die Präzision jeder Falte betrachten.. Um dies zu tun, stellen Sie sicher, dass speichernVorhersagen Parameter auf TRUE in der trainControl-Funktion ().

pred <- nb_fit$pred
pred$gleich <- ansonsten(pred $ pred == pred $ obs, 1,0)
jedesmal <- Vor%>%                                        
  gruppiere nach(Resample) %>%                         
  zusammenfassen_at(deren(gleich),                     
               aufführen(Genauigkeit = Mittelwert))              
jedesmal

Hier ist die Präzisionstabelle in jeder Falte.

43261screen20shot202021-03-1220at2022-04-18-3228982

Wir können es auch in die Grafik einzeichnen, um die Analyse zu erleichtern. In diesem Fall, wir verwenden den Boxplot, um unsere Genauigkeiten darzustellen.

ggplot(data=eachfold, aes(x=Resample, y=Genauigkeit, Gruppe=1)) +
geom_boxplot(Farbe="kastanienbraun") +
geom_point() +
thema_minimal()

84001screen20shot202021-03-1420at2015-27-02-3273577

Wir können sehen, dass jede der Falten eine Präzision erreicht, die sich nicht viel voneinander unterscheidet. Die niedrigste Genauigkeit ist 72,58%, und auch im Boxplot, Wir sehen keine Ausreißer. Das bedeutet, dass unser Modell bei der Kreuzvalidierung von k mal gut funktioniert hat.

Was kommt als nächstes

  • Probieren Sie eine andere Anzahl von Falten aus
  • Nehmen Sie eine Parametereinstellung vor
  • Andere Datensätze und Methoden verwenden

Kurzbiographie des Autors

Ich heiße Muhammad Arnold, ein Enthusiast für maschinelles Lernen und Data Science. Derzeit Masterstudent in Informatik in Indonesien.

Die in diesem Artikel gezeigten Medien sind nicht Eigentum von DataPeaker und werden nach Ermessen des Autors verwendet.

Abonniere unseren Newsletter

Wir senden Ihnen keine SPAM-Mail. Wir hassen es genauso wie du.