Dans cet article, nous classerons les images de CIFAR-10 à l'aide de PyTorch. Suivez le Tutoriel officiel avec des commentaires. De plus, Python et l'apprentissage automatique sont des super débutants.
Un jeu de données d'image à 10 étiquettes largement utilisé dans le domaine de l'apprentissage automatique. airplane、automobile、bird、cat、deer、dog、frog、horse、ship、truck 10 étiquettes sont disponibles.
Site officiel émettra une commande d'installation en fonction de chaque environnement. Puisque je suis un macOS, exécutez ce qui suit pour installer.
pip install torch torchvision
#Importer NumPy, Matplotlib, PyTorch
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
#ToTensor: Image en échelle de gris (RVB 0)~255 à 0~Normaliser à la plage de 1), Normaliser: valeur Z (moyenne RVB et écart type à 0).Normaliser avec 5)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
#Télécharger les données d'entraînement
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
#Télécharger les données de test
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=2)
#Ensemble de données d'entraînement: 50000 images RVB avec 32 pixels de hauteur et de largeur
print(trainset.data.shape)
(50000, 32, 32, 3)
#Jeu de données de test: 10000 images RVB avec 32 pixels de hauteur et de largeur
print(testset.data.shape)
(10000, 32, 32, 3)
#Consultez la liste des cours
print(trainset.classes)
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
#Les classes sont souvent utilisées, alors gardez-les séparément
classes = trainset.classes
Dans le document officiel, ** avion a été redéfini comme avion ** et ** automobile a été redéfini comme voiture **. Pourquoi?
#Essayez d'afficher l'image téléchargée
def imshow(img):
#Dénormaliser
img = img / 2 + 0.5
# torch.Du type Tensor à numpy.Convertir en type ndarray
print(type(img)) # <class 'torch.Tensor'>
npimg = img.numpy()
print(type(npimg))
#Convertir la forme de (RVB, vertical, horizontal) à (vertical, horizontal, RVB)
print(npimg.shape)
npimg = np.transpose(npimg, (1, 2, 0))
print(npimg.shape)
#Afficher l'image
plt.imshow(npimg)
plt.show()
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
#Mettre en œuvre CNN
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
Définissez chaque couche avec init et connectez-les avec forward.
#Entropie croisée
criterion = nn.CrossEntropyLoss()
#Méthode de descente de gradient probabiliste
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
#Entraîner
for epoch in range(2):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
#Erreur de propagation de retour
loss.backward()
optimizer.step()
train_loss = loss.item()
running_loss += loss.item()
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
[1, 2000] loss: 2.164
[1, 4000] loss: 1.863
[1, 6000] loss: 1.683
[1, 8000] loss: 1.603
[1, 10000] loss: 1.525
[1, 12000] loss: 1.470
[2, 2000] loss: 1.415
[2, 4000] loss: 1.369
[2, 6000] loss: 1.363
[2, 8000] loss: 1.333
[2, 10000] loss: 1.314
[2, 12000] loss: 1.317
Finished Training
La valeur moyenne de la perte pour chaque 2000 mini-lot est sortie dans le journal.
#Enregistrer le modèle
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
Enregistrez le modèle dans le répertoire courant avec l'extension pth (PyTorch).
#Charger les données de test et afficher l'image et l'étiquette correcte
dataiter = iter(testloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
#Chargez le modèle enregistré et prédisez
net = Net()
net.load_state_dict(torch.load(PATH))
outputs = net(images)
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
GroundTruth: truck cat airplane ship Predicted: truck horse airplane ship
Vous pouvez voir que les prédictions sont correctes sauf pour chat.
print(outputs)
value, predicted = torch.max(outputs, 1)
print(value)
print(predicted)
tensor([[ 0.7114, -2.2724, 0.1225, 0.9470, 2.1940, 1.8655, -2.6655, 4.1646,
-1.1001, -1.6991],
[-2.2453, -4.1017, 1.8291, 3.2079, 1.1242, 3.6712, 1.0010, 1.0489,
-3.2010, -1.9476],
[-3.0669, -3.8900, 0.9312, 3.5649, 2.7791, 1.5095, 2.1216, 1.5274,
-4.3077, -2.2234],
[-2.0948, -3.4640, 2.4833, 2.6210, 4.0590, 1.8350, 0.4924, 0.7212,
-3.5043, -2.4212]], grad_fn=<AddmmBackward>)
tensor([4.1646, 3.6712, 3.5649, 4.0590], grad_fn=<MaxBackward0>)
tensor([7, 5, 3, 4])
** torch.max ** renvoie la valeur maximale des sorties.
correct = 0
total = 0
#Calculer sans se souvenir du gradient (sans apprentissage)
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
Accuracy of the network on the 10000 test images: 60 %
Vous pouvez voir que le taux de réponse correct pour 10000 données de test est de 60%.
Vous trouverez ci-dessous une note personnelle pour les débutants en Python. Pas bon ** (prédit == étiquettes) .sum (). Item () ** Je ne savais pas comment écrire ceci, donc je vais me déconnecter et vérifier.
print(type((predicted == labels)))
print((predicted == labels).dtype)
print(type((predicted == labels).sum()))
print((predicted == labels).sum())
print((predicted == labels).sum().item())
# <class 'torch.Tensor'>
# torch.bool
# <class 'torch.Tensor'>
# tensor(2)
# 2
Je vois. Comparez chaque élément du tableau et utilisez sum () implémenté dans torch.Tensor pour calculer la valeur totale de true. Ensuite, item () implémenté dans torch.Tensor est utilisé pour faire de la valeur totale une valeur numérique de type int. C'était un peu plus facile à comprendre quand je l'ai vérifié avec numpy.
#Essayez avec numpy
a = np.array([1, 2, 3, 4, 5])
b = np.array([1, 2, 0, 4, 5])
print(type((a == b)))
print((a == b))
print((a == b).sum())
print(type((a == b).sum()))
print((a == b).sum().item())
print(type((a == b).sum().item()))
# <class 'numpy.ndarray'>
# [ True True False True True]
# 4
# <class 'numpy.int64'>
# 4
# <class 'int'>
En regardant Official, vous pouvez utiliser presque la même API que ndarray, donc ** sum () ** et ** item () ** Peut être utilisé. Convaincu.
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs, 1)
c = (predicted == labels).squeeze()
for i in range(4):
label = labels[i]
class_correct[label] += c[i].item()
class_total[label] += 1
for i in range(10):
print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
Accuracy of airplane : 72 %
Accuracy of automobile : 66 %
Accuracy of bird : 38 %
Accuracy of cat : 58 %
Accuracy of deer : 60 %
Accuracy of dog : 29 %
Accuracy of frog : 73 %
Accuracy of horse : 60 %
Accuracy of ship : 69 %
Accuracy of truck : 73 %
Est-ce que c'est comme ça dans un tutoriel?