Beim N-Falten eines vorhandenen "Datensatzes" werden 1,2, ... N, 1,2, ..., N, 1,2, ..., N, 1,2 ,. .. und teilen Sie den Datensatz. Es wird verwendet, um zu verhindern, dass die Daten beim Teilen von Zeitreihendaten auf einen bestimmten Monat oder eine bestimmte Jahreszeit verzerrt werden.
from torch.utils.data import Dataset
class LayeredFoldWrapper(Dataset):
def __init__(self, dataset, n_splits=5, fold=0, valid=False):
self.dataset = dataset
self.n_splits = n_splits
self.fold = fold
self.valid = valid
self.valid_index = list(self._valid_index(len(dataset), n_splits, fold))
self.train_index = list(set(range(len(dataset))) - set(self.valid_index))
def __len__(self):
return len(self._get_index_list(self.valid))
def __getitem__(self, i):
return self.dataset.__getitem__(self._get_index_list(self.valid)[i])
def _valid_index(self, N, n_splits, fold):
"""
N:Anzahl der Gesamtdaten
n_splits:Anzahl der Teilungen in Falte
fold:Wert zur Angabe jeder Falte 0<=fold<=n_splits-1
"""
assert(0<=fold<=n_splits-1)
return range(n_splits - fold - 1, N+1, n_splits)
def _get_index_list(self, valid):
if valid:
return self.valid_index
else:
return self.train_index
Recommended Posts