[PYTHON] Mise en œuvre de l'apprentissage continu en utilisant la distance Maharanobis de l'espace des fonctionnalités

Apprentissage continu

L'apprentissage continu signifie que le modèle apprend continuellement de nouvelles données les unes après les autres pendant une longue période de temps. Pour plus de détails, reportez-vous à Slide qui résume l'apprentissage continu.

Cette fois, nous mettrons en œuvre un article d'apprentissage continu en utilisant la méthode de détection hors distribution. Titre de l'article: Un cadre unifié simple pour détecter les échantillons hors distribution et les attaques contradictoires Résumé de l'article: Les données hors distribution sont considérées comme les données de la nouvelle classe, l'ajustement gaussien est effectué sur l'espace des caractéristiques du modèle profond et les données de test sont classées en fonction de la distance Maharanobis entre le vecteur moyen de l'ancienne classe et la nouvelle classe. L'explication détaillée de l'article a été écrite dans la deuxième partie de Slide.

La force de cet article est que ** l'apprentissage continu est possible sans compromettre la précision des classes apprises jusqu'à présent **, et il est nécessaire de réapprendre les paramètres de DNN en utilisant les données de la classe nouvellement ajoutée. Il n'y a pas **. Par conséquent, même si le nombre de classes augmente, l'apprentissage peut être effectué très rapidement.

Paramètres du problème de mise en œuvre

Explication de la mise en œuvre

Première préparation

import os 

import numpy as np
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.nn as nn
from torchvision import transforms as T
from torchvision.datasets import CIFAR10

# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
#Placez le modèle et les utils dans le référentiel ci-dessus dans la même hiérarchie
from model import EfficientNet 
from tqdm import tqdm
plt.style.use("ggplot")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

J'ai préparé une fonction qui renvoie un dataloader pour que je puisse ajouter une classe de CIFAR10

def return_data_loader(classes, train=True, batch_size=128):
    transform = []
    transform.append(T.Resize((64, 64))) #Besoin de redimensionner pour utiliser efmodel
    transform.append(T.ToTensor())
    transform = T.Compose(transform)

    dataset = CIFAR10("./data", train=train, download=True, transform=transform)
    targets = np.array(dataset.targets)
    mask = np.array([t in classes for t in targets])
    dataset.data = dataset.data[mask]
    dataset.targets = targets[mask]
    
    data_loader = DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=train)
    return data_loader

Utilisez le plus petit modèle d'efficacenet comme modèle

Tout d'abord, comme la flèche bleue fait partie de la figure, nous apprenons généralement 5 classes de modèles de discrimination. Ensuite, en tant que partie flèche rouge, rapprochez l'entité précédente avec une distribution gaussienne pour chaque classe. スクリーンショット 2020-10-05 9.08.56.png

NCLASS = 5 #Classe initiale
classes = np.arange(NCLASS)
model = 'efficientnet-b0'
weight_dir = "."

clf = EfficientNet.from_name(model)
clf._fc = torch.nn.Linear(clf._fc.in_features, NCLASS)
clf = clf.to(device)
clf.train()
train_loader = return_data_loader(classes=classes, train=True)

lr = 0.001
epoch_num = 50
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(clf.parameters(), lr=lr)

for epoch in tqdm(range(epoch_num)):
    train_loss = 0
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)
        logit = clf(x)
        loss = criterion(logit, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
    train_loss /= len(train_loader.dataset)
torch.save(clf.state_dict(), os.path.join(weight_dir, 'weight.pth'))

test_loader = return_data_loader(range(10), train=False)
clf.load_state_dict(torch.load(os.path.join(weight_dir, 'weight.pth')))
clf.eval()

pred = []
true = []
for x, y in test_loader:
    with torch.no_grad():
            pred.extend(clf(x.to(device)).max(1)[1].detach().cpu().numpy())
            true.extend(y.numpy())
print(accuracy_score(true, pred))
print(confusion_matrix(true, pred))

Tout d'abord, nous sortons la matrice mixte et le taux de précision lorsque le modèle de discrimination a été normalement entraîné. Puisque seules les classes 0 à 4 sont utilisées pour l'apprentissage, il est naturellement impossible de prédire les classes 5 à 9 et les classes 0 à 4 sont prédites de force.

0.4279 #Taux de réponse correct
[[877  27  47  35  14   0   0   0   0   0]
 [ 14 972   3   8   3   0   0   0   0   0]
 [ 51   7 785  81  76   0   0   0   0   0]
 [ 20  18 107 780  75   0   0   0   0   0]
 [ 13   2  58  62 865   0   0   0   0   0]
 [ 13  12 226 640 109   0   0   0   0   0]
 [ 26  55 232 477 210   0   0   0   0   0]
 [ 47  21 188 230 514   0   0   0   0   0]
 [604 214  53  95  34   0   0   0   0   0]
 [160 705  43  78  14   0   0   0   0   0]]

Ensuite, calculez la moyenne et la covariance des caractéristiques pour la mise en œuvre de la partie flèche rouge dans la figure ci-dessus. スクリーンショット 2020-10-05 9.29.18.png

def ext_feature(x):
    z = clf.extract_features(x)
    z = clf._avg_pooling(z)
    z = z.flatten(start_dim=1)
    return z.detach().cpu().numpy()

train_loaders = [return_data_loader(classes=[c], train=True) for c in range(10)]

z_mean = []
z_var = 0
target_count = []

for c in tqdm(range(NCLASS)): #Classe existante
    N = len(train_loaders[c].dataset) #Tenant le numéro de chaque classe
    target_count.append(N)
    
    with torch.no_grad():
        #Calcul moyen
        new_z_mean = 0
        for x, _ in train_loaders[c]:
            x = x.to(device)
            new_z_mean += ext_feature(x).sum(0) / N
        z_mean.append(new_z_mean)

        #Calcul de la variance
        for x, _ in train_loaders[c]:
            x = x.to(device)    
            z_var += (ext_feature(x) - new_z_mean).T.dot(ext_feature(x) - new_z_mean) / N

C = len(z_mean)
z_var /=  C
z_mean = np.array(z_mean)
target_count = np.array(target_count)

Une fois la moyenne et la co-dispersion obtenues, il est possible de classer sans la couche entièrement connectée de la couche finale en utilisant la distance de Maharanobis. L'implémentation utilise le théorème bayésien スクリーンショット 2020-10-05 9.32.07.png Où $ \ beta_c $ est le nombre de données de classe

z_var_inv = np.linalg.inv(z_var + np.eye(z_mean.shape[1])*1e-6)  
#Ajouter une régularisation pour éviter que la matrice inverse ne devienne instable
A = z_mean.dot(z_var_inv) #Un élément du contenu de l'exp de la molécule
B = (A*z_mean).sum(1) * 0.5 #2 éléments
beta = np.log(target_count) #3 éléments

accs = []
pred = []
true = []
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        pred.extend((A.dot(ext_feature(x).T) - B[:, None] + beta[:, None]).argmax(0))
        true.extend(y.numpy())
acc = accuracy_score(true, pred)        
print(acc)
accs.append(acc)
confusion_matrix(true, pred)

À partir des résultats suivants, il a été constaté que presque le même taux de précision peut être obtenu sans utiliser la couche entièrement connectée.

0.4273 #Taux de réponse correct
array([[899,  17,  43,  29,  12,   0,   0,   0,   0,   0],
       [ 25, 958,   6,   9,   2,   0,   0,   0,   0,   0],
       [ 55,   6, 785,  86,  68,   0,   0,   0,   0,   0],
       [ 29,  15, 109, 773,  74,   0,   0,   0,   0,   0],
       [ 23,   2,  55,  62, 858,   0,   0,   0,   0,   0],
       [ 22,   6, 227, 641, 104,   0,   0,   0,   0,   0],
       [ 34,  39, 256, 468, 203,   0,   0,   0,   0,   0],
       [ 71,  16, 199, 214, 500,   0,   0,   0,   0,   0],
       [653, 182,  53,  84,  28,   0,   0,   0,   0,   0],
       [221, 645,  42,  78,  14,   0,   0,   0,   0,   0]])

Mise en œuvre de l'apprentissage continu

Le but est de classer les données de test pour toutes les classes en fonction de la moyenne et de la variance des nouvelles données, sans entraîner les paramètres du modèle.

Le contour de l'algorithme est le suivant スクリーンショット 2020-10-05 9.39.07.png

for c in tqdm(range(NCLASS, 10)): #Nouvelle classe
    N = len(train_loaders[c].dataset)

    with torch.no_grad():
        #Calcul moyen
        new_z_mean = 0        
        for x, _ in train_loaders[c]:
            x = x.to(device)
            new_z_mean += ext_feature(x).sum(0) / N 


        #Calcul de la variance
        new_z_var = 0
        for x, _ in train_loaders[c]:
            x = x.to(device)    
            new_z_var += (ext_feature(x) - new_z_mean).T.dot(ext_feature(x) - new_z_mean) / N

    #Mettre à jour la moyenne et la variance
    C = len(target_count)
    z_mean = np.concatenate([z_mean, new_z_mean[None, :]])
    z_var = z_var*C/(C+1) + new_z_var/(C+1)
    target_count = np.append(target_count, N)

    z_var_inv = np.linalg.inv(z_var + np.eye(z_mean.shape[1])*1e-6)
    A = z_mean.dot(z_var_inv) 
    B = (A*z_mean).sum(1) * 0.5
    beta = np.log(target_count)
    pred = []
    true = []
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            pred.extend((A.dot(ext_feature(x).T) - B[:, None] + beta[:, None]).argmax(0))
            true.extend(y.numpy())
    acc = accuracy_score(true, pred)
    accs.append(acc)
    print(acc)

Le résultat final est

0.4974 #Taux de réponse correct
array([[635,   1,  18,   4,   2,  14,   9,  36, 260,  21],
       [  1, 761,   0,   1,   0,   0,   8,   3,  21, 205],
       [ 20,   0, 581,  12,   8,  97, 105, 135,  35,   7],
       [  5,   0,  22, 450,  13, 256, 147,  60,  29,  18],
       [  2,   1,  16,  10, 555,  30,  63, 302,  20,   1],
       [  1,   0,  57, 288,  22, 325, 173, 106,  22,   6],
       [  5,   0,  49, 139,  36, 182, 350, 161,  35,  43],
       [  5,   2,  34,  50, 131, 104, 158, 446,  58,  12],
       [226,  26,  13,  11,   3,  22,  58,  41, 430, 170],
       [ 17, 250,   6,   5,   0,   8,  69,  16, 188, 441]])
plt.title("accuracy")
plt.plot(accs)
plt.show()

L'axe des x signifie le nombre de classes ajoutées. On peut voir que le taux de réponse correct lorsque 10 classes de données d'apprentissage sont finalement données est environ 0,1 plus élevé que lorsque seulement 5 classes sont données. download-6.png

Recommended Posts

Mise en œuvre de l'apprentissage continu en utilisant la distance Maharanobis de l'espace des fonctionnalités
Calculer la distance de Maharanobis en tenant compte de la corrélation des quantités d'entités à l'aide de scipy
Règles d'apprentissage Widrow-Hoff implémentées en Python
Implémentation des règles d'apprentissage Perceptron en Python
Modèle de reconnaissance d'image utilisant l'apprentissage profond en 2016
Astuces de fourniture de données utilisant deque dans l'apprentissage automatique