[PYTHON] Falten Sie den Pytorch-Datensatz in Schichten

Beim N-Falten eines vorhandenen "Datensatzes" werden 1,2, ... N, 1,2, ..., N, 1,2, ..., N, 1,2 ,. .. und teilen Sie den Datensatz. Es wird verwendet, um zu verhindern, dass die Daten beim Teilen von Zeitreihendaten auf einen bestimmten Monat oder eine bestimmte Jahreszeit verzerrt werden.

from torch.utils.data import Dataset


class LayeredFoldWrapper(Dataset):
    def __init__(self, dataset, n_splits=5, fold=0, valid=False):
        self.dataset = dataset
        self.n_splits = n_splits
        self.fold = fold
        self.valid = valid
        self.valid_index = list(self._valid_index(len(dataset), n_splits, fold))
        self.train_index = list(set(range(len(dataset))) - set(self.valid_index))

    def __len__(self):
        return len(self._get_index_list(self.valid))

    def __getitem__(self, i):
        return self.dataset.__getitem__(self._get_index_list(self.valid)[i])

    def _valid_index(self, N, n_splits, fold):
        """
        N:Anzahl der Gesamtdaten
        n_splits:Anzahl der Teilungen in Falte
        fold:Wert zur Angabe jeder Falte 0<=fold<=n_splits-1
        """
        assert(0<=fold<=n_splits-1)
        return range(n_splits - fold - 1, N+1, n_splits)

    def _get_index_list(self, valid):
        if valid:
            return self.valid_index
        else:
            return self.train_index

Recommended Posts

Falten Sie den Pytorch-Datensatz in Schichten
Datensatzvorbereitung für PyTorch
Implementieren Sie Style Transfer mit Pytorch
[Pytorch] Memo über Dataset / DataLoader
[PyTorch] Überprüfen Sie, ob sich Modell und Datensatz im Cuda-Modus befinden
Ich habe versucht, Pytorchs Datensatz zu erklären
Ich habe Gray Scale mit Pytorch geschrieben
LightGBM UserWarning: Verwenden von categoryical_feature in Dataset
So rufen Sie PyTorch in Julia an