Erfahren Sie, wie Sie bei Verwendung von Dataset mit Pytorch eine Kreuzvalidierung durchführen.
Sie können torch.utils.data.dataset.Subset
verwenden, um einen Datensatz durch Angabe eines Index aufzuteilen. Kombinieren Sie dies mit dem scikit-learn sklearn.model_selection
.
train_test_split
Verwenden Sie sklearn.model_selection.train_test_split
, um den Index in train_index
und valid_index
aufzuteilen, und verwenden Sie Subset
, um den Datensatz aufzuteilen.
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import Subset
from sklearn.model_selection import train_test_split
dataset = get_dataset()
train_index, valid_index = train_test_split(range(len(dataset)), test_size=0.3)
batch_size = 16
train_dataset = Subset(dataset, train_index)
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
valid_dataset = Subset(dataset, valid_index)
valid_dataloader = DataLoader(valid_dataset, batch_size, shuffle=False)
#Code hier lernen
Verwenden Sie sklearn.model_selection.KFold
, um den Index in train_index
und valid_index
aufzuteilen, und verwenden Sie Subset
, um den Datensatz aufzuteilen.
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import Subset
from sklearn.model_selection import KFold
dataset = get_dataset()
batch_size = 16
kf = KFold(n_splits=3)
for _fold, (train_index, test_index) in enumerate(kf.split(X)):
train_dataset = Subset(dataset, train_index)
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
valid_dataset = Subset(dataset, valid_index)
valid_dataloader = DataLoader(valid_dataset, batch_size, shuffle=False)
#Code hier lernen
Wenn es sich um einen Klassenklassifizierungsdatensatz handelt, sollten Sie in der Lage sein, den Wert von "y" zu erhalten, indem Sie "Datensatz [:] [1]" ausführen, sodass Sie auch "Stratified KFold" ausführen können sollten.
Recommended Posts