Cet article est une suite de «[CGAN (GAN conditionnel) Génère MNIST (KMNIST)» (https://qiita.com/kyamada101/items/5195b1b32c60168b1b2f). C'est un record quand on essaie de faire ACGAN basé sur cGAN.
Je pense que c'est une idée naturelle en termes d'évolution depuis cGAN, mais quand je la mets en œuvre, c'est assez ...
Depuis que j'ai brièvement présenté cGAN dans l'article précédent, je vais expliquer brièvement ACGAN. ACGAN est, en un mot, ** "cGAN où le discriminateur effectue également des tâches de classification" **. C'est une méthode qui permet la sortie d'images avec plus de variations.
L'article original est [ici](Synthèse d'image conditionnelle avec les GAN de classificateur auxiliaire)
A. Odena, C. Olah, J. Shlens. Conditional Image Synthesis With Auxiliary Classifier GANs. CVPR, 2016
En ce qui concerne les articles d'ACGAN, certaines personnes ont publié les articles originaux, ce sera donc utile.
Article de référence
Explication des articles sur AC-GAN (synthèse d'image conditionnelle avec les GAN de classificateur auxiliaire)
Dans cGAN, l'image authentique / fausse et les informations d'étiquette étaient entrées dans Discriminator, et l'identification de l'authentique ou du faux était sortie. D'un autre côté, dans ACGAN, l'entrée de Discriminator n'est qu'une image, et pas seulement l'identification du vrai ou du faux mais aussi le jugement de classe pour deviner quelle classe il est ajouté à la sortie. Il ressemble à ce qui suit lorsqu'il est écrit dans un diagramme. La partie «classe» de la figure est la sortie de la classification prédite par Discriminator. Comme «label», il se présente sous la forme d'un vecteur de dimension de numéro de classe.
ACGAN a l'implémentation PyTorch sur GitHub. Avec cela comme référence, modifions l'implémentation de cGAN que j'ai écrit dans l'article précédent.
Que faire
Est presque tout. Ensuite, la structure de Discriminator ressemble à ceci. Il s'agit d'un dessin du diagramme de structure du discriminateur cGAN publié dans l'article précédent, mais la partie représentée en rouge est le changement d'ACGAN.
Une implémentation de Discriminator.
python
class Discriminator(nn.Module):
def __init__(self, num_class):
super(Discriminator, self).__init__()
self.num_class = num_class
self.conv = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1), #L'entrée est 1 canal(Parce que c'est noir et blanc),Nombre de filtres 64,Taille de filtre 4*4
nn.LeakyReLU(0.2),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(128),
)
self.fc = nn.Sequential(
nn.Linear(128 * 7 * 7, 1024),
nn.BatchNorm1d(1024),
nn.LeakyReLU(0.2),
)
self.fc_TF = nn.Sequential(
nn.Linear(1024, 1),
nn.Sigmoid(),
)
self.fc_class = nn.Sequential(
nn.Linear(1024, num_class),
nn.LogSoftmax(dim=1),
)
self.init_weights()
def init_weights(self):
for module in self.modules():
if isinstance(module, nn.Conv2d):
module.weight.data.normal_(0, 0.02)
module.bias.data.zero_()
elif isinstance(module, nn.Linear):
module.weight.data.normal_(0, 0.02)
module.bias.data.zero_()
elif isinstance(module, nn.BatchNorm1d):
module.weight.data.normal_(1.0, 0.02)
module.bias.data.zero_()
elif isinstance(module, nn.BatchNorm2d):
module.weight.data.normal_(1.0, 0.02)
module.bias.data.zero_()
def forward(self, img):
x = self.conv(img)
x = x.view(-1, 128 * 7 * 7)
x = self.fc(x)
x_TF = self.fc_TF(x)
x_class = self.fc_class(x)
return x_TF, x_class
Il semble y avoir différentes manières d'ajouter le résultat de la classification. Dans l'implémentation PyTorch du lien que j'ai posté plus tôt, la couche Linear
était bifurquée à la fin, donc je l'implémente de la même manière ici.
Selon ce changement, la fonction par époque ressemble à ceci.
python
def train_func(D_model, G_model, batch_size, z_dim, num_class, TF_criterion, class_criterion,
D_optimizer, G_optimizer, data_loader, device):
#Mode entraînement
D_model.train()
G_model.train()
#La vraie étiquette est 1
y_real = torch.ones((batch_size, 1)).to(device)
D_y_real = (torch.rand((batch_size, 1))/2 + 0.7).to(device) #Etiquette de bruit à mettre en D
#La fausse étiquette est 0
y_fake = torch.zeros((batch_size, 1)).to(device)
D_y_fake = (torch.rand((batch_size, 1)) * 0.3).to(device) #Etiquette de bruit à mettre en D
#Initialisation de la perte
D_running_TF_loss = 0
G_running_TF_loss = 0
D_running_class_loss = 0
D_running_real_class_loss = 0
D_running_fake_class_loss = 0
G_running_class_loss = 0
#Calcul lot par lot
for batch_idx, (data, labels) in enumerate(data_loader):
#Ignorer si inférieur à la taille du lot
if data.size()[0] != batch_size:
break
#Création de bruit
z = torch.normal(mean = 0.5, std = 1, size = (batch_size, z_dim)) #Moyenne 0.Générer des nombres aléatoires selon une distribution normale de 5
real_img, label, z = data.to(device), labels.to(device), z.to(device)
#Mise à jour du discriminateur
D_optimizer.zero_grad()
#Mettre une image réelle dans Discriminator et propager vers l'avant ⇒ Calcul des pertes
D_real_TF, D_real_class = D_model(real_img)
D_real_TF_loss = TF_criterion(D_real_TF, D_y_real)
CEE_label = torch.max(label, 1)[1].to(device)
D_real_class_loss = class_criterion(D_real_class, CEE_label)
#Mettre l'image créée en mettant du bruit dans Generator dans Discriminator et propager vers l'avant ⇒ Calcul de la perte
fake_img = G_model(z, label)
D_fake_TF, D_fake_class = D_model(fake_img.detach()) #fake_Stop Loss calculé dans les images pour qu'il ne se propage pas vers Generator
D_fake_TF_loss = TF_criterion(D_fake_TF, D_y_fake)
D_fake_class_loss = class_criterion(D_fake_class, CEE_label)
#Minimiser la somme de deux pertes
D_TF_loss = D_real_TF_loss + D_fake_TF_loss
D_class_loss = D_real_class_loss + D_fake_class_loss
D_TF_loss.backward(retain_graph=True)
D_class_loss.backward()
D_optimizer.step()
D_running_TF_loss += D_TF_loss.item()
D_running_class_loss += D_class_loss.item()
D_running_real_class_loss += D_real_class_loss.item()
D_running_fake_class_loss += D_fake_class_loss.item()
#Mise à jour du générateur
G_optimizer.zero_grad()
#L'image créée en mettant du bruit dans le Générateur est placée dans le Discriminateur et propagée vers l'avant ⇒ La partie détectée devient Perte
fake_img_2 = G_model(z, label)
D_fake_TF_2, D_fake_class_2 = D_model(fake_img_2)
#G perte(max(log D)Optimisé avec)
G_TF_loss = -TF_criterion(D_fake_TF_2, y_fake)
G_class_loss = class_criterion(D_fake_class_2, CEE_label) #Du point de vue de G, ce serait bien s'il pensait que D était réel et lui donnait un cours.
G_TF_loss.backward(retain_graph=True)
G_class_loss.backward()
G_optimizer.step()
G_running_TF_loss += G_TF_loss.item()
G_running_class_loss -= G_class_loss.item()
D_running_TF_loss /= len(data_loader)
D_running_class_loss /= len(data_loader)
D_running_real_class_loss /= len(data_loader)
D_running_fake_class_loss /= len(data_loader)
G_running_TF_loss /= len(data_loader)
G_running_class_loss /= len(data_loader)
return D_running_TF_loss, G_running_TF_loss, D_running_class_loss, G_running_class_loss, D_running_real_class_loss, D_running_fake_class_loss
En plus des changements mentionnés précédemment, j'ai également modifié le bruit à ajouter. La dernière fois, il s'agissait d'une distribution normale avec 30 dimensions, moyenne 0,5 et écart type 0,2, mais cette fois, il s'agit d'une distribution normale avec 100 dimensions, moyenne 0,5 et écart type 1.
La perte de classification est torch.nn.NLLLoss ()
. Cela correspondait également à la mise en œuvre du lien plus tôt.
Le premier est le graphique des pertes. Dans ACGAN, il existe deux types de perte, la perte d'identification réelle ou fausse et la perte de classification, et les deux pertes sont propagées à la fois au générateur et au discriminateur. Il est également tracé séparément dans le graphique.
«T / F_loss» est la perte (ligne continue) pour une identification authentique / fausse, et «class_loss» est la perte (ligne pointillée) pour la classification.
En regardant cela, il semble que cela fonctionne. Pourtant... Ceci est un gif lorsqu'une image de chaque étiquette est générée pour chaque époque. J'ai entré les informations d'étiquette de sorte que la ligne supérieure soit "Ah, I, U ..." à partir de la gauche, et la partie inférieure droite est "..., N, ゝ". Il n'y a pratiquement aucune correspondance entre l'étiquette et l'image générée. Mais on dirait qu'il produit "du texte sur une autre étiquette" plutôt qu'une image complètement dénuée de sens.
Semblable à cGAN, j'ai essayé de générer 5 "A" à "ゝ" par Generator après un entraînement de 100 points. N'est-ce pas seulement "ke" qui semble correspondre à Label-chan? (Au contraire, le mode s'effondre complètement ...)
D'ailleurs, c'est le résultat de la génération de cGAN après 100 périodes d'entraînement dans les mêmes conditions. De toute évidence, cGAN génère des caractères plus proches de l'étiquette.
À première vue sur la sortie, chez ACGAN, Generator et Discriminator ** pensent que les caractères avec une forme différente sont les caractères de cette étiquette ** (Ex: Discriminator et Generator sont tous les deux "I" N'est-ce pas (celui qui ressemble à la forme est traité comme l'étiquette «A»)? J'ai pensé.
Il s'agit d'un graphique qui divise la perte de la classification du discriminateur en la perte dérivée de l'image réelle et la perte dérivée de la fausse image (= image créée par le générateur). sum_class_loss
est la valeur totale (= identique à la ligne pointillée rouge dans le graphique précédent).
En regardant ce graphique, Discriminator commet une erreur en jugeant l'image réelle (en particulier dans les premiers stades de l'apprentissage) et devine le jugement de la fausse image.
(En termes numériques, real_class_loss
est environ 20 fois la valeur au début de fake_class_loss
et environ 5 fois à la fin)
En d'autres termes, ** l'image créée par le Générateur avec l'étiquette «A» est traitée comme «A» par Discriminator même si la forme réelle est assez différente de «A» **. Je peux imaginer ça.
Idéalement peut-être, la perte de la classification devrait être à peu près la même pour les images réelles et fausses.
Comme mentionné dans l'article original d'ACGAN, il semble que ** s'il y a trop de classes, la qualité de l'image de sortie se détériorera sur le même réseau **. Dans l'article original, ImageNet (1000 classes) est divisé en 10 classes x 100 cas pour l'expérimentation.
Par conséquent, j'ai décidé d'essayer ceci dans 5 classes une fois.
Faisons la même structure de réseau et essayons de générer 5 caractères de "A" à "O".
Le graphique des pertes est similaire. Il semble qu'il y ait encore de la place pour que le T / F_loss
diminue.
Il y a aussi des inégalités ici, mais la seconde moitié est assez belle.
Ensuite, générons 5 images chacune après une formation de 100 époques.
Il semble que le mode ne s'effondre pas.
Ensuite, c'est la perte de la classification de Discriminateur.
Sur une base numérique, il y avait une différence d'environ 10 fois au début, mais c'est presque la même valeur dans l'étape finale, mais il est difficile de voir ce graphique, donc je ne l'afficherai qu'après 3 époques.
Si vous regardez ceci, vous pouvez voir que real_class_loss
et fake_class_loss
deviennent des valeurs assez proches.
En premier lieu, y a-t-il une différence de 10 à 20 fois entre la véritable classification et la fausse classification de la 1ère époque dans les premiers stades de l'apprentissage? ?? J'ai pensé, alors j'ai essayé d'afficher la perte pour chaque iter (pour chaque mini-lot). Il est vrai que la valeur de la perte ne change pas entre «real_class_loss» et «fake_class_loss» au début, mais vous pouvez voir que «fake_class_loss» baisse fortement.
J'ai essayé de n'entraîner que l'image réelle dans les premières époques, mais cela n'avait pas beaucoup de sens, alors j'ai décidé de ne pré-apprendre que la tâche de classification.
Obtenez uniquement le discriminateur et ne résolvez que la tâche de classification.
La convergence est assez rapide, donc je ne fais que 20 époques. En conséquence, il est subtil, mais pour le moment, j'utiliserai ce Discriminateur après une formation de 20 époques.
La perte vraie / fausse est presque la même que sans pré-apprentissage. Perte de classification Quant à elle, elle est devenue assez petite depuis le début.
Maintenant, regardons la perte de classification dérivée de l'image réelle et de la fausse image. J'ai essayé d'apprendre jusqu'à 300 époques. Par rapport au non pré-appris, la valeur de perte dérivée de l'image réelle est également considérablement inférieure. C'est environ quatre fois plus que la perte dérivée de la fausse image, mais ce n'est toujours pas la même valeur.
Jetons un coup d'oeil à l'image générée par ACGAN après cette formation de 300 époques. Hmm. .. Aucun effet n'est observé. Il n'y a pas d'augmentation du nombre de caractères réussis et une réduction du mode se produit.
Il existe plusieurs ensembles de données kuzuji qui ont un grand nombre de données par caractère, 6000 et seulement environ 300 à 400. Je pense que plus le nombre de données par classe est grand, mieux c'est, alors j'ai pensé que cela pourrait fonctionner si le nombre de données était supérieur à CIFAR-10, mais ce n'était pas bon.
Personnellement, la distance entre les caractères de chaque étiquette dans l'espace latent n'est-elle pas proche (= les caractères avec des étiquettes différentes sont assez proches dans l'espace latent)? Je pense. Dans l'expérience de l'article original, j'ai expérimenté CIFAR-10 et ImageNet toutes les 10 classes, mais dans le cas des caractères indésirables, il n'y avait qu'un peu plus de la moitié des personnages qui travaillaient dans 10 classes, et cela ne fonctionnait que dans 5 classes.
En tout cas, il semble assez difficile de viser et de sortir la classe 49 avec ACGAN, donc je vais abandonner ...
Recommended Posts