Lors du pliage N d'un Dataset
existant, 1,2, ... N, 1,2, ..., N, 1,2, ..., N, 1,2,. .. et divisez l'ensemble de données. Il est utilisé pour éviter que les données ne soient biaisées vers un mois ou une saison spécifique lors de la division des données de séries chronologiques.
from torch.utils.data import Dataset
class LayeredFoldWrapper(Dataset):
def __init__(self, dataset, n_splits=5, fold=0, valid=False):
self.dataset = dataset
self.n_splits = n_splits
self.fold = fold
self.valid = valid
self.valid_index = list(self._valid_index(len(dataset), n_splits, fold))
self.train_index = list(set(range(len(dataset))) - set(self.valid_index))
def __len__(self):
return len(self._get_index_list(self.valid))
def __getitem__(self, i):
return self.dataset.__getitem__(self._get_index_list(self.valid)[i])
def _valid_index(self, N, n_splits, fold):
"""
N:Nombre total de données
n_splits:Nombre de divisions dans le pli
fold:Valeur pour spécifier chaque pli 0<=fold<=n_splits-1
"""
assert(0<=fold<=n_splits-1)
return range(n_splits - fold - 1, N+1, n_splits)
def _get_index_list(self, valid):
if valid:
return self.valid_index
else:
return self.train_index
Recommended Posts