[PYTHON] Comment augmenter les données avec PyTorch

Augmentation des données avec PyTorch

Voici comment gonfler vos données avec PyTorch. En ce qui concerne PyTorch lui-même, j'ai déjà écrit un article d'introduction sur le blog, alors veuillez vous référer à ce qui suit si vous le souhaitez.

Introduction au framework d'apprentissage profond en vedette "PyTorch"

Veuillez vous référer à l'article suivant pour les raisons de l'implémentation du remplissage de données et des exemples spécifiques.

Méthode de remplissage des données d'image (Data Augmentation) pour améliorer la précision de l'apprentissage en profondeur pour comprendre tout en jouant avec des matériaux gratuits

De plus, cet article est rédigé en supposant qu'il sera exécuté dans "Google Colaboratory (Google Colab)". Google Colab lui-même dépasse le cadre de cet article. Si vous ne savez pas, veuillez vous référer à l'article suivant.

Si vous utilisez Google Colaboratory, vous n'avez pas besoin de créer un environnement et vous pouvez apprendre Python gratuitement.

Le code utilisé dans cet article est résumé dans le cahier suivant.

pytorch_data_preprocessing.ipynb

Cliquez sur l'icône "Ouvrir dans Colab" au milieu pour l'ouvrir dans Google Colab et l'exécuter tel quel.

Traitement des données dans PyTorch

Tout d'abord, vérifions le traitement des données dans PyTorch.

Télécharger les données des enseignants

Téléchargez d'abord les données de l'enseignant. L'explication est omise.

!git clone https://github.com/karaage0703/janken_dataset datasets
!rm -rf /content/datasets/.git
!rm /content/datasets/LICENSE

Le répertoire a la structure suivante. Choki, gu, pa, chaque répertoire contient des images de formes de main choki, goo et par.

datasets
├── choki
├── gu
└── pa

Définissez dataset_root_dir comme suit:

dataset_root_dir = '/content/datasets'

Créer un jeu de données

Tout d'abord, importez les bibliothèques requises.

import torch
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import PIL

Utilisez ImageFolder pour charger les images dans le dossier en tant qu'ensemble de données.

dataset = datasets.ImageFolder(root=dataset_root_dir)

Vérification du jeu de données

Vous pouvez vérifier le contenu de l'ensemble de données avec getitem. (# Ci-dessous le résultat de l'exécution).

print(dataset.__getitem__(0))
print(dataset.__getitem__(100))
print(dataset.__getitem__(150))
# (<PIL.Image.Image image mode=RGB size=320x240 at 0x7F11DB6DC160>, 0)
# (<PIL.Image.Image image mode=RGB size=320x240 at 0x7F11DB6DCF28>, 1)
# (<PIL.Image.Image image mode=RGB size=320x240 at 0x7F12297D2C50>, 2)

Pour vérifier le contenu avec matplotlib, suivez les étapes ci-dessous.

image_numb = 6 #Veuillez spécifier un multiple de 3
for i in range(0, image_numb):
  ax = plt.subplot(image_numb / 3, 3, i + 1)
  plt.tight_layout()
  ax.set_title(str(i))
  plt.imshow(dataset[i][0])

data_01.png

torchvision.transforms Dans PyTorch, les transformations peuvent être utilisées pour prétraiter divers traitements d'images, y compris l'augmentation des données.

Pour une inversion horizontale / verticale typique, les transformations sont écrites sous la forme suivante.

data_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
])

Après cela, si vous le spécifiez dans l'argument de transformation d'ImageFolder, l'ensemble de données avec le traitement d'image spécifié dans les transformations sera défini.

dataset_augmentated = datasets.ImageFolder(root=dataset_root_dir, transform=data_transform)

Vérifions les données.

image_numb = 6 #Veuillez spécifier un multiple de 3
for i in range(0, image_numb):
  ax = plt.subplot(image_numb / 3, 3, i + 1)
  plt.tight_layout()
  ax.set_title(str(i))
  plt.imshow(dataset_augmentated[i][0])

data_02.png

C'est à l'envers.

Consultez le bloc-notes Google Colab pour obtenir des exemples d'autres fonctions de transformation. Des techniques telles que l'effacement aléatoire sont également implémentées en standard. Si vous voulez tout savoir, veuillez vous référer au document officiel.

implémentation d'albumentations

Il s'agit d'un moyen simple d'utiliser une bibliothèque pour l'augmentation des données appelée albumentations avec PyTorch.

Tout d'abord, installez les albums avec la commande suivante.

! pip install albumentations

Importez les bibliothèques requises.

import albumentations as albu
import numpy as np
from PIL import Image

Comme pour la transformation, j'aimerais utiliser Image Folder pour gonfler les données avec albumation, mais un peu de technique est nécessaire.

Vous pouvez facilement utiliser les fonctions d'albumations avec Image Folder en appliquant ce qui suit.

albu_transforms = albu.Compose([
  albu.RandomRotate90(p=0.5),
  albu.RandomGamma(gamma_limit=(85, 115), p=0.2),
])

def albumentations_transform(image, transform=albu_transforms):    
  if transform:
    image_np = np.array(image)
    augmented = transform(image=image_np)
    image = Image.fromarray(augmented['image'])
  return image

data_transform = transforms.Compose([
  transforms.Lambda(albumentations_transform),
])

dataset_augmentated = datasets.ImageFolder(root=dataset_root_dir, transform=data_transform)

Vérifions le contenu des données.

image_numb = 6 #Veuillez spécifier un multiple de 3
for i in range(0, image_numb):
  ax = plt.subplot(image_numb / 3, 3, i + 1)
  plt.tight_layout()
  ax.set_title(str(i))
  plt.imshow(dataset_augmentated[i][0])

data_albu.png

Vous pouvez voir que le traitement d'image des albumentations est terminé.

Après quelques recherches, lors de l'utilisation d'albumations, il semble que les ensembles de données soient souvent implémentés indépendamment sans utiliser ImageFolder, mais c'est une technique pratique lorsque vous voulez l'essayer facilement avec ImageFolder.

Vous pouvez découvrir les fonctionnalités des albumentations dans le Jupyter Notebook de albumentations-examples publié par @Kazuhito sur GitHub. Devenir.

De plus, le Jupyter Notebook de @ Kazuhito est modifié pour fonctionner avec Google Colab ci-dessous, donc si vous voulez réellement le déplacer de vos propres mains, veuillez vous y référer.

albumentations_examples.ipynb (version compatible Google Colab)

mixup Le référentiel GitHub suivant a été utile lors de l'utilisation de la méthode populaire de gonflement des données avec PyTorch en raison de ses performances.

hongyi-zhang/mixup

Pour plus de détails sur la façon de mélanger et de vérifier les données après le mélange, reportez-vous au bloc-notes Google Colab.

pytorch_data_preprocessing.ipynb

Dans le cas de Keras, les articles suivants peuvent être utiles.

Augmentation de la confusion chez Keras

Résumé

Nous avons résumé comment gonfler les données (Data Augmentation) avec PyTorch et comment vérifier les données. Veuillez nous indiquer s'il existe des fonctions plus pratiques ou des méthodes plus intelligentes.

Article associé

Déplacez-vous et vérifiez ce que vous faites avec l'augmentation des données de l'API de détection d'objets de TensorFlow

Recommended Posts

Comment augmenter les données avec PyTorch
Comment gérer les données déséquilibrées
Comment lire les données de problème avec Paiza
Augmentation des données avec openCV
Comment créer des exemples de données CSV avec hypothèse
Comment récupérer des données de courses de chevaux avec Beautiful Soup
Comment lire les données de séries chronologiques dans PyTorch
Afficher l'image après l'augmentation des données avec Pytorch
Comment mettre à jour avec SQLAlchemy?
Comment lancer avec Theano
Comment modifier avec SQLAlchemy?
Comment séparer les chaînes avec ','
[PyTorch] Augmentation des données pour la segmentation
Comment faire RDP sur Fedora31
Comment gérer les trames de données
Comment supprimer avec SQLAlchemy?
Comment utiliser xgboost: classification multi-classes avec des données d'iris
Comment récupérer des données d'image de Flickr avec Python
Comment convertir des données détenues horizontalement en données détenues verticalement avec des pandas
Comment obtenir plus de 1000 données avec SQLAlchemy + MySQLdb
Comment extraire des données qui ne manquent pas de valeur nan avec des pandas
Comment extraire des données qui ne manquent pas de valeur nan avec des pandas
[Python] Comment FFT des données mp3
Python: comment utiliser async avec
Comment lire les données de la sous-région e-Stat
Pour utiliser virtualenv avec PowerShell
Comment installer python-pip avec ubuntu20.04LTS
Comment démarrer avec Scrapy
Comment démarrer avec Python
Comment gérer l'erreur DistributionNotFound
Comment démarrer avec Django
Comment calculer la date avec python
Comment installer mysql-connector avec pip3
Comment INNER JOIN avec SQL Alchemy
Comment installer Anaconda avec pyenv
[Introduction à Python] Comment obtenir des données avec la fonction listdir
Comment collecter des données d'apprentissage automatique
Comment appeler PyTorch dans Julia
Comment extraire des fonctionnalités de données de séries chronologiques avec les bases de PySpark
Comment effectuer un traitement arithmétique avec le modèle Django
Comment titrer plusieurs figures avec matplotlib
Afficher l'image après l'augmentation des données avec PyTorch
Comment collecter des données Twitter sans programmation
Comment obtenir l'identifiant du parent avec sqlalchemy
Comment ajouter un package avec PyCharm
Comment utiliser OpenVPN avec Ubuntu 18.04.3 LTS
Comment utiliser Cmder avec PyCharm (Windows)
Convertir des données Excel en JSON avec python
Comment empêcher les mises à jour de paquets avec apt
Comment utiliser BigQuery en Python
Comment utiliser Ass / Alembic avec HtoA
Convertissez des données FX 1 minute en données 5 minutes avec Python
Comment gérer les erreurs de compatibilité d'énumération
Comment utiliser le japonais avec le tracé NLTK
Comment faire un test de sac avec python
Comment rechercher Google Drive dans Google Colaboratory