When splitting an existing Dataset into N folds, this wrapper labels the samples cyclically (1, 2, ..., N, 1, 2, ..., N, 1, 2, ...) and splits the dataset according to those labels. It is used to prevent the folds from being biased toward a specific month or season when dividing time-series data.
from torch.utils.data import Dataset
class LayeredFoldWrapper(Dataset):
    """Expose one side (train or validation) of an interleaved fold split.

    Validation samples for fold ``fold`` are the positions
    ``n_splits - fold - 1``, ``2 * n_splits - fold - 1``, ... of the wrapped
    dataset, i.e. every ``n_splits``-th sample.  All remaining positions form
    the training side.  Picking every ``n_splits``-th sample instead of one
    contiguous chunk spreads each fold across the whole dataset.

    Args:
        dataset: Object supporting ``len(dataset)`` and ``dataset[i]``.
        n_splits: Number of folds; must be at least 1.
        fold: Which fold is held out, ``0 <= fold <= n_splits - 1``.
        valid: If true, the wrapper serves the validation samples; otherwise
            it serves the training samples.

    Raises:
        ValueError: If ``fold`` is outside ``[0, n_splits - 1]``.
    """

    def __init__(self, dataset, n_splits=5, fold=0, valid=False):
        super().__init__()
        self.dataset = dataset
        self.n_splits = n_splits
        self.fold = fold
        self.valid = valid

        total = len(dataset)
        self.valid_index = list(self._valid_index(total, n_splits, fold))
        # Filter in order rather than using a set difference, so the training
        # indices are guaranteed to come out in ascending order.
        held_out = set(self.valid_index)
        self.train_index = [i for i in range(total) if i not in held_out]

    def __len__(self):
        """Return the number of samples on the selected side of the split."""
        return len(self._get_index_list(self.valid))

    def __getitem__(self, i):
        """Return the ``i``-th sample of the selected side of the split."""
        return self.dataset[self._get_index_list(self.valid)[i]]

    def _valid_index(self, N, n_splits, fold):
        """Return the validation positions for the requested fold.

        N: Number of total data
        n_splits: Number of fold splits
        fold: Value to specify each fold, 0 <= fold <= n_splits - 1

        Raises ValueError when ``fold`` is out of range.
        """
        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if not 0 <= fold <= n_splits - 1:
            raise ValueError(
                f"fold must satisfy 0 <= fold <= n_splits - 1, "
                f"got fold={fold}, n_splits={n_splits}"
            )
        # Stop at N (exclusive): valid positions are 0 .. N-1.  A stop of
        # N + 1 would emit the out-of-range position N whenever it happens to
        # fall on this fold's stride.
        return range(n_splits - fold - 1, N, n_splits)

    def _get_index_list(self, valid):
        """Return the validation index list if ``valid`` is true, else train."""
        return self.valid_index if valid else self.train_index
        Recommended Posts