Comme le titre l'indique, j'ai implémenté Single Shot Multibox Detector (SSD) avec PyTorch
([https://github.com] / jjjkkkjjj / pytorch_SSD](https://github.com/jjjkkjjj/pytorch_SSD)) Cependant, le calcul est calculé par rapport à ssd.pytorch etc. C'est lent (je vais enquêter sur la cause plus tard). </ strike> [^ 1] Cependant, j'ai travaillé dur sur l'abstraction, donc je pense qu'elle est hautement personnalisable. (Je ne sais pas si je peux documenter comment l'utiliser correctement ...). Si vous vérifiez la mise en œuvre du SSD, vous en trouverez beaucoup, mais cette fois,
[^ 1]: Je n'ai tout simplement pas défini l'argument d'initialisation num_workers
de DataLoader
. .. .. Actuellement, il est aussi rapide que ssd.pytorch.
――Je voulais le personnaliser librement
Je l'ai implémenté pour cette raison. Il existe de nombreuses explications faciles à comprendre (reference) dans de nombreux articles, mais j'aimerais les résumer à ma manière pour organiser mon esprit. Donc, cette fois, je voudrais résumer autour de l'ensemble de données.
PyTorch
Tout d'abord, j'aborderai brièvement le cadre d'apprentissage profond PyTorch
utilisé cette fois. En fait, quand j'ai essayé d'implémenter SSD pour la première fois, j'ai utilisé Tensorflow
. cependant,
compat.v1
, compat.v2
, etc.?)Par conséquent, la mise en œuvre ne s'est pas déroulée facilement. En particulier, «difficile à déboguer» m'a été fatal et je ne comprenais pas bien la traduction. Eh bien, je pense qu'il serait utile d'apprendre à l'utiliser simplement parce que je n'ai pas compris Tensorflow
. .. ..
Je pensais qu'il ne serait pas terminé pour le reste de ma vie, alors je l'ai changé en Pytorch
, qui a la particularité de pouvoir effectuer des opérations similaires à Numpy
. Lorsque je suis passé à PyTorch
, la" difficulté de débogage "que je ressentais dans Tensorflow a été considérablement améliorée et la mise en œuvre s'est bien déroulée. Eh bien, le fonctionnement de Numpy
est Matft (https://github.com/jjjkkkjjj/Matft), [J'ai créé une bibliothèque de calcul matriciel à N dimensions Matft avec Swift] Familier avec l'implémentation de (https://qiita.com/jjjkkjjj/items/1f2b5c3835b1600d3129)) et MILES (https://github.com/jjjkkjjj/MIL) Parce que ça l'était, «PyTorch» était parfait pour moi.
En premier lieu, je voudrais aborder brièvement ce qu'est le SSD, puis le résumer en détail. SSD est un algorithme de détection d'objets qui peut prédire la position et l'étiquette d'un objet de bout en bout. C'est un chiffre approprié, mais si vous donnez une image d'entrée comme celle-ci, le SSD affichera la position et l'étiquette de l'objet à la fois.
Ce que fait et fait ce modèle est comme suit. Je vais vous expliquer étape par étape.
Dans le papier SSD original, les ensembles de données sont PASCAL VOC2007, [PASCAL VOC2012](http: //host.robots.ox). .ac.uk / pascal / VOC / voc2012 /), COCO2014 sont utilisés. COCO n'est pas encore implémenté, je vais donc vous expliquer le jeu de données VOC. </ strike> Tout d'abord, parlons du jeu de données VOC.
La structure du répertoire est fondamentalement unifiée et ressemble à ce qui suit.
répertoire voc
$ tree -I '*.png|*.xml|*.jpg|*.txt'
└── VOCdevkit
└── VOC20**
├── Annotations
├── ImageSets
│ ├── Action
│ ├── Layout
│ ├── Main
│ └── Segmentation
├── JPEGImages
├── SegmentationClass
└── SegmentationObject
La détection d'objet nécessite des «Annotations», «JPEGImages» et «ImagesSets / Main» directement sous «VOC20 **». Chacun est comme suit.
.xml
des données d'annotation. C'est un avec le fichier .jpeg
de JPEGImages
..jpeg
de l'image. C'est un avec le fichier .xml
de ʻAnnotations`..txt
qui représente des informations sur l'ensemble de données. Les noms de fichiers ʻAnnotation et
JPEGImages`, qui sont les éléments de l'ensemble, sont décrits..xml
)Le fichier .xml
sous le répertoire ʻAnnotaions` est le suivant.
Annotations/~~.xml
<annotation>
<folder>VOC2007</folder>
<filename>000005.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
<flickrid>325991873</flickrid>
</source>
<owner>
<flickrid>archintent louisville</flickrid>
<name>?</name>
</owner>
<size>
<width>500</width>
<height>375</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>chair</name>
<pose>Rear</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>263</xmin>
<ymin>211</ymin>
<xmax>324</xmax>
<ymax>339</ymax>
</bndbox>
</object>
<object>
<name>chair</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>165</xmin>
<ymin>264</ymin>
<xmax>253</xmax>
<ymax>372</ymax>
</bndbox>
</object>
...
</annotation>
Les points importants sont les suivants.
<filename>
.jpeg
auquel ces données d'annotation correspondent<object>
<name>
--Nom de l'étiquette<truncated>
<difficult>
<bndbox>
--Bounding box (position de l'objet). notation des coinsPour implémenter un ensemble de données, vous devez étendre la classe Dataset
. Ensuite, il est nécessaire d'implémenter «len» qui renvoie le nombre d'ensembles de données et «getitem» qui renvoie les données d'entrée et l'étiquette de réponse correcte pour «index» dans la plage du nombre d'ensembles de données.
Ce que nous faisons avec la mise en œuvre suivante
.xml
directement sous ʻAnnotations dans
self._annopaths` sous forme de listeself._annopaths [index]
à partir de ʻindex donné par
getitem`est.
PyTorch
est utilisé. Le modèle pré-entraîné de PyTorch
entraîne l'image normalisée par ordre RVB,mean = (0.485, 0.456, 0.406)
,` std = (0.229, 0.224, 0.225) ʻen entrée. (Référence)ObjectDetectionDatasetBase
class ObjectDetectionDatasetBase(_DatasetBase):
def __init__(self, ignore=None, transform=None, target_transform=None, augmentation=None):
réduction
def __getitem__(self, index):
"""
:param index: int
:return:
img : rgb image(Tensor or ndarray)
targets : Tensor or ndarray of bboxes and labels [box, label]
= [xmin, ymin, xmamx, ymax, label index(or relu_one-hotted label)]
or
= [cx, cy, w, h, label index(or relu_one-hotted label)]
"""
img = self._get_image(index)
bboxes, linds, flags = self._get_bbox_lind(index)
img, bboxes, linds, flags = self.apply_transform(img, bboxes, linds, flags)
# concatenate bboxes and linds
if isinstance(bboxes, torch.Tensor) and isinstance(linds, torch.Tensor):
if linds.ndim == 1:
linds = linds.unsqueeze(1)
targets = torch.cat((bboxes, linds), dim=1)
else:
if linds.ndim == 1:
linds = linds[:, np.newaxis]
targets = np.concatenate((bboxes, linds), axis=1)
return img, targets
def apply_transform(self, img, bboxes, linds, flags):
"""
IMPORTATANT: apply transform function in order with ignore, augmentation, transform and target_transform
:param img:
:param bboxes:
:param linds:
:param flags:
:return:
Transformed img, bboxes, linds, flags
"""
# To Percent mode
height, width, channel = img.shape
# bbox = [xmin, ymin, xmax, ymax]
# [bbox[0] / width, bbox[1] / height, bbox[2] / width, bbox[3] / height]
bboxes[:, 0::2] /= float(width)
bboxes[:, 1::2] /= float(height)
if self.ignore:
bboxes, linds, flags = self.ignore(bboxes, linds, flags)
if self.augmentation:
img, bboxes, linds, flags = self.augmentation(img, bboxes, linds, flags)
if self.transform:
img, bboxes, linds, flags = self.transform(img, bboxes, linds, flags)
if self.target_transform:
bboxes, linds, flags = self.target_transform(bboxes, linds, flags)
return img, bboxes, linds, flags
VOCDatasetBase
VOC_class_labels = ['aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor']
VOC_class_nums = len(VOC_class_labels)
class VOCSingleDatasetBase(ObjectDetectionDatasetBase):
def __init__(self, voc_dir, focus, ignore=None, transform=None, target_transform=None, augmentation=None, class_labels=None):
"""
:param voc_dir: str, voc directory path above 'Annotations', 'ImageSets' and 'JPEGImages'
e.g.) voc_dir = '~~~~/trainval/VOCdevkit/voc2007'
:param focus: str, image set name. Assign txt file name under 'ImageSets' directory
:param ignore: target_transforms.Ignore
:param transform: instance of transforms
:param target_transform: instance of target_transforms
:param augmentation: instance of augmentations
:param class_labels: None or list or tuple, if it's None use VOC_class_labels
"""
super().__init__(ignore=ignore, transform=transform, target_transform=target_transform, augmentation=augmentation)
self._voc_dir = voc_dir
self._focus = focus
self._class_labels = _check_ins('class_labels', class_labels, (list, tuple), allow_none=True)
if self._class_labels is None:
self._class_labels = VOC_class_labels
layouttxt_path = os.path.join(self._voc_dir, 'ImageSets', 'Main', self._focus + '.txt')
if os.path.exists(layouttxt_path):
with open(layouttxt_path, 'r') as f:
filenames = f.read().splitlines()
filenames = [filename.split()[0] for filename in filenames]
self._annopaths = [os.path.join(self._voc_dir, 'Annotations', '{}.xml'.format(filename)) for filename in filenames]
else:
raise FileNotFoundError('layout: {} was invalid arguments'.format(focus))
@property
def class_nums(self):
return len(self._class_labels)
@property
def class_labels(self):
return self._class_labels
def _jpgpath(self, filename):
"""
:param filename: path containing .jpg
:return: path of jpg
"""
return os.path.join(self._voc_dir, 'JPEGImages', filename)
def __len__(self):
return len(self._annopaths)
"""
Detail of contents in voc > https://towardsdatascience.com/coco-data-format-for-object-detection-a4c5eaf518c5
VOC bounding box (xmin, ymin, xmax, ymax)
"""
def _get_image(self, index):
"""
:param index: int
:return:
rgb image(ndarray)
"""
root = ET.parse(self._annopaths[index]).getroot()
img = cv2.imread(self._jpgpath(_get_xml_et_value(root, 'filename')))
# pytorch's image order is rgb
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img.astype(np.float32)
def _get_bbox_lind(self, index):
"""
:param index: int
:return:
list of bboxes, list of bboxes' label index, list of flags([difficult, truncated,...])
"""
linds = []
bboxes = []
flags = []
root = ET.parse(self._annopaths[index]).getroot()
for obj in root.iter('object'):
linds.append(self._class_labels.index(_get_xml_et_value(obj, 'name')))
bndbox = obj.find('bndbox')
# bbox = [xmin, ymin, xmax, ymax]
bboxes.append([_get_xml_et_value(bndbox, 'xmin', int), _get_xml_et_value(bndbox, 'ymin', int), _get_xml_et_value(bndbox, 'xmax', int), _get_xml_et_value(bndbox, 'ymax', int)])
flags.append({'difficult': _get_xml_et_value(obj, 'difficult', int) == 1})#,
#'partial': _get_xml_et_value(obj, 'truncated', int) == 1})
return np.array(bboxes, dtype=np.float32), np.array(linds, dtype=np.float32), flags
La structure du répertoire est la même que celle de VOC en ce sens qu'il est divisé en annotations et images (ʻannotationset ʻimages / {train ou val} 20 **
), mais le traitement des annotations est légèrement différent.
├── annotations
│ ├── captions_train2014.json
│ ├── captions_val2014.json
│ ├── instances_train2014.json
│ ├── instances_val2014.json
│ ├── person_keypoints_train2014.json
│ └── person_keypoints_val2014.json
└── images
├── train2014
└── val2014
Comme vous pouvez le voir, contrairement à VOC, toutes les annotations sont écrites dans un seul fichier.
Et ce dont vous avez besoin pour la détection d'objets est le fichier ʻinstances_ {train ou val} 20 **. Json. Le format est décrit en détail dans [Officiel](http://cocodataset.org/#format-data). Et comme [python api](https://github.com/cocodataset/cocoapi) est préparé dans COCO, le fichier d'annotation pour la détection d'objet est ʻinstances_ {train ou val} 20 **. Json
. Si vous comprenez, honnêtement, vous n'avez pas besoin de bien comprendre le contenu.
Juste au cas où, quand je vérifie le format, ça ressemble à ça.
{
"info": info,
"images": [image],
"annotations": [annotation],
"licenses": [license],
}
info{
"year": int,
"version": str,
"description": str,
"contributor": str,
"url": str,
"date_created": datetime,
}
image{
"id": int,
"width": int,
"height": int,
"file_name": str,
"license": int,
"flickr_url": str,
"coco_url": str,
"date_captured": datetime,
}
license{
"id": int,
"name": str,
"url": str,
}
L '"annotation" et les "catégories" de détection d'objet sont les suivantes.
annotation{
"id": int,
"image_id": int,
"category_id": int,
"segmentation": RLE or [polygon],
"area": float, "bbox": [x,y,width,height],
"iscrowd": 0 or 1,
}
categories[{
"id": int,
"name": str,
"supercategory": str,
}]
Mettez-le simplement en œuvre comme VOC. Les informations nécessaires sont acquises via l'objet «COCO» de l'API.
COCO_class_labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
COCO_class_nums = len(COCO_class_labels)
COCO2014_ROOT = os.path.join(DATA_ROOT, 'coco', 'coco2014')
class COCOSingleDatasetBase(ObjectDetectionDatasetBase):
def __init__(self, coco_dir, focus, ignore=None, transform=None, target_transform=None, augmentation=None, class_labels=None):
"""
:param coco_dir: str, coco directory path above 'annotations' and 'images'
e.g.) coco_dir = '~~~~/coco2007/trainval'
:param focus: str or str, directory name under images
e.g.) focus = 'train2014'
:param ignore: target_transforms.Ignore
:param transform: instance of transforms
:param target_transform: instance of target_transforms
:param augmentation: instance of augmentations
:param class_labels: None or list or tuple, if it's None use VOC_class_labels
"""
super().__init__(ignore=ignore, transform=transform, target_transform=target_transform, augmentation=augmentation)
self._coco_dir = coco_dir
self._focus = focus
self._class_labels = _check_ins('class_labels', class_labels, (list, tuple), allow_none=True)
if self._class_labels is None:
self._class_labels = COCO_class_labels
self._annopath = os.path.join(self._coco_dir, 'annotations', 'instances_' + self._focus + '.json')
if os.path.exists(self._annopath):
self._coco = COCO(self._annopath)
else:
raise FileNotFoundError('json: {} was not found'.format('instances_' + self._focus + '.json'))
# remove no annotation image
self._imageids = list(self._coco.imgToAnns.keys())
@property
def class_nums(self):
return len(self._class_labels)
@property
def class_labels(self):
return self._class_labels
def _jpgpath(self, filename):
"""
:param filename: path containing .jpg
:return: path of jpg
"""
return os.path.join(self._coco_dir, 'images', self._focus, filename)
def __len__(self):
return len(self._imageids)
"""
Detail of contents in voc > https://towardsdatascience.com/coco-data-format-for-object-detection-a4c5eaf518c5
VOC bounding box (xmin, ymin, xmax, ymax)
"""
def _get_image(self, index):
"""
:param index: int
:return:
rgb image(ndarray)
"""
"""
self._coco.loadImgs(self._imageids[index]): list of dict, contains;
license: int
file_name: str
coco_url: str
height: int
width: int
date_captured: str
flickr_url: str
id: int
"""
filename = self._coco.loadImgs(self._imageids[index])[0]['file_name']
img = cv2.imread(self._jpgpath(filename))
# pytorch's image order is rgb
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img.astype(np.float32)
def _get_bbox_lind(self, index):
"""
:param index: int
:return:
list of bboxes, list of bboxes' label index, list of flags([difficult, truncated,...])
"""
linds = []
bboxes = []
flags = []
# anno_ids is list
anno_ids = self._coco.getAnnIds(self._imageids[index])
# annos is list of dict
annos = self._coco.loadAnns(anno_ids)
for anno in annos:
"""
anno's keys are;
segmentation: list of float
area: float
iscrowd: int, 0 or 1
image_id: int
bbox: list of float, whose length is 4
category_id: int
id: int
"""
"""
self._coco.loadCats(anno['category_id']) is list of dict, contains;
supercategory: str
id: int
name: str
"""
cat = self._coco.loadCats(anno['category_id'])[0]
linds.append(self.class_labels.index(cat['name']))
# bbox = [xmin, ymin, w, h]
xmin, ymin, w, h = anno['bbox']
# convert to corners
xmax, ymax = xmin + w, ymin + h
bboxes.append([xmin, ymin, xmax, ymax])
"""
flag = {}
keys = ['iscrowd']
for key in keys:
if key in anno.keys():
flag[key] = anno[key] == 1
else:
flag[key] = False
flags.append(flag)
"""
flags.append({'difficult': anno['iscrowd'] == 1})
return np.array(bboxes, dtype=np.float32), np.array(linds, dtype=np.float32), flags
Augmentation L'augmentation n'est pas toujours nécessaire, mais aussi dans Article original
Data augmentation is crucial
Cela semble important car il est mentionné comme. Dans l'article original, la méthode spécifique est omise, mais il existe deux principaux types de méthodes d'augmentation.
Dans ce qui suit, j'écrirai sur la façon dont cette image originale est augmentée.
Geometric Distortions Dans Distorsions géométriques, il existe les trois méthodes suivantes.
Random Expand
――Comme son nom l'indique, la taille est agrandie de manière aléatoire.
--Remplissez les marges pour l'expansion de la taille avec la valeur moyenne rgb_mean = (103.939, 116.779, 123.68)
utilisée dans la normalisation.
Random Sample
Échantillonner au hasard.
--A ce moment, le seuil du degré de chevauchement (valeur IoU) entre l'image échantillonnée et la boîte englobante est déterminé de manière aléatoire.
--Un des (0.1,0.3,0.5,0.7,0.9, Aucun)
"Aucun" n'a pas de seuil. Cependant, ʻIoU = 0` sans chevauchement est exclu. --Répétez jusqu'à ce que l'image échantillon dépasse le seuil.
Je vais omettre la mise en œuvre pendant un moment. Plus précisément, c'est ici. Voir ici pour d'autres exemples d'images.
Photometric Distortions Avec les distorsions photométriques, il existe les cinq méthodes suivantes.
Je vais omettre la mise en œuvre pendant un moment. Plus précisément, c'est ici. Voir ici pour d'autres exemples d'images.
Transform
Il s'agit du prétraitement de l'image d'entrée.
ndarray
en torch.Tensor
-Convertir de $ [0,255] $ à $ [0,1] $Le traitement dans ↑ est implémenté comme suit. PyTorch
a une fonction de traitement transforms pour l'image d'entrée prétraitée, mais l'image PIL
Parce que c'est une fonction pour, je l'ai fait pour Opencv. Si vous créez vos propres transformations
, vous devez correspondre au traitement dans la méthode de classe de l'ensemble de données ci-dessus. Cette fois, la méthode _apply_transform
transmet ʻimg,
bboxes,
linds,
flags, c'est-à-dire des informations de drapeau telles que l'image, la boîte englobante, l'étiquette et difficile comme arguments comme indiqué ci-dessous. (* Au fait, je l'ai omis, mais ʻaugmentation
a la même méthode d'implémentation.)
_apply_méthode de transformation
if self.ignore:
bboxes, linds, flags = self.ignore(bboxes, linds, flags)
if self.augmentation:
img, bboxes, linds, flags = self.augmentation(img, bboxes, linds, flags)
if self.transform:
img, bboxes, linds, flags = self.transform(img, bboxes, linds, flags)
if self.target_transform:
bboxes, linds, flags = self.target_transform(bboxes, linds, flags)
Par conséquent, vous devez implémenter la méthode __call__ (self, img, bboxes, linds, flags)
.
La classe ToTensor
qui se convertit en torch.Tensor
a l'ordre de l'image d'entrée de Conv2d
de PyTorch
à partir de l'ordre (h, w, c)
d'OpenCV(b, c, h, w). Converti en
.
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img, bboxes, labels, flags):
for t in self.transforms:
img, bboxes, labels, flags = t(img, bboxes, labels, flags)
return img, bboxes, labels, flags
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class ToTensor(object):
"""
Note that convert ndarray to tensor and [0-255] to [0-1]
"""
def __call__(self, img, *args):
# convert ndarray into Tensor
# transpose img's tensor (h, w, c) to pytorch's format (c, h, w). (num, c, h, w)
img = np.transpose(img, (2, 0, 1))
return (torch.from_numpy(img).float() / 255., *args)
class Resize(object):
def __init__(self, size):
"""
:param size: 2d-array-like, (height, width)
"""
self._size = size
def __call__(self, img, *args):
return (cv2.resize(img, self._size), *args)
class Normalize(object):
#def __init__(self, rgb_means=(103.939, 116.779, 123.68), rgb_stds=(1.0, 1.0, 1.0)):
def __init__(self, rgb_means=(0.485, 0.456, 0.406), rgb_stds=(0.229, 0.224, 0.225)):
self.means = np.array(rgb_means, dtype=np.float32).reshape((-1, 1, 1))
if np.any(np.abs(self.means) > 1):
logging.warning("In general, mean value should be less than 1 because img's range is [0-1]")
self.stds = np.array(rgb_stds, dtype=np.float32).reshape((-1, 1, 1))
def __call__(self, img, *args):
if isinstance(img, torch.Tensor):
return ((img.float() - torch.from_numpy(self.means)) / torch.from_numpy(self.stds), *args)
else:
return ((img.astype(np.float32) - self.means) / self.stds, *args)
Exemple d'utilisation
from data import transforms
transform = transforms.Compose(
[transforms.Resize((300, 300)),
transforms.ToTensor(),
transforms.Normalize(rgb_means=(0.485, 0.456, 0.406), rgb_stds=(0.229, 0.224, 0.225))]
)
Target transform
Boîte englobante et conversion d'étiquette.
--Convertir la boîte englobante de la représentation des coins à la représentation des centres de gravité
--Convertir l'étiquette en vecteur Onehot
--Convertir de ndarray
à torche.Tensor
--Bounding box et étiquette concatenate
(shape = (box num, 4 = (cx, cy, w, h) + class_nums + 1 = (background))
)
Il ya trois.
représentation centroïdes
--Utilisez la coordonnée centrale $ (c_x, c_y) $ et la largeur / hauteur $ (w, h)
expression des coins
--Utilisez la coordonnée supérieure gauche $ (x_ {min}, y_ {min}) $ et la coordonnée inférieure droite $ (x_ {max}, y_ {max})
expression minmax --Utilisez la coordonnée supérieure gauche $ (x_ {min}, y_ {min}) $ et la coordonnée inférieure droite $ (x_ {max}, y_ {max}) $.
L'ordre est différent de la notation des coins.
--Coordonnée centrale $ (c_x, c_y) $ et largeur / hauteur $ (w, h) $, coordonnée supérieure gauche $ (x_ {min}, y_ {min}) $ et coordonnée inférieure droite $ (x_ {max} , y_ {max}) $ relation
\begin{align}
(c_x,c_y) &= (\frac{x_{min}+x_{max}}{2},\frac{y_{min}+y_{max}}{2}) \\
(w,h) &= (x_{max}-x_{min},y_{max}-y_{min})
\end{align}
Le processus de ↑ est implémenté comme suit. Le traitement correct des étiquettes pour la détection d'objet target_transforms
n'existe pas dans PyTorch
, vous devez donc créer vous-même target_transforms. Encore une fois, la méthode _apply_transform
transmet les informations d'indicateur telles que bboxes
, linds
, flags
, c'est-à-dire la zone de délimitation, l'étiquette et difficile comme arguments comme indiqué ci-dessous.
_apply_méthode de transformation
if self.ignore:
bboxes, linds, flags = self.ignore(bboxes, linds, flags)
if self.augmentation:
img, bboxes, linds, flags = self.augmentation(img, bboxes, linds, flags)
if self.transform:
img, bboxes, linds, flags = self.transform(img, bboxes, linds, flags)
if self.target_transform:
bboxes, linds, flags = self.target_transform(bboxes, linds, flags)
Par conséquent, vous pouvez implémenter la méthode __call__ (self, bboxes, linds, flags)
.
class ToTensor(object):
def __call__(self, bboxes, labels, flags):
return torch.from_numpy(bboxes), torch.from_numpy(labels), flags
class ToCentroids(object):
def __call__(self, bboxes, labels, flags):
# bbox = [xmin, ymin, xmax, ymax]
bboxes = np.concatenate(((bboxes[:, 2:] + bboxes[:, :2]) / 2,
(bboxes[:, 2:] - bboxes[:, :2])), axis=1)
return bboxes, labels, flags
class ToCorners(object):
def __call__(self, bboxes, labels, flags):
# bbox = [cx, cy, w, h]
bboxes = np.concatenate((bboxes[:, :2] - bboxes[:, 2:]/2,
bboxes[:, :2] + bboxes[:, 2:]/2), axis=1)
return bboxes, labels, flags
class OneHot(object):
def __init__(self, class_nums, add_background=True):
self._class_nums = class_nums
self._add_background = add_background
if add_background:
self._class_nums += 1
def __call__(self, bboxes, labels, flags):
if labels.ndim != 1:
raise ValueError('labels might have been already relu_one-hotted or be invalid shape')
labels = _one_hot_encode(labels.astype(np.int), self._class_nums)
labels = np.array(labels, dtype=np.float32)
return bboxes, labels, flags
Exemple d'utilisation
target_transform = target_transforms.Compose(
[target_transforms.ToCentroids(),
target_transforms.OneHot(class_nums=datasets.VOC_class_nums, add_background=True),
target_transforms.ToTensor()]
)
Le traitement de l'ensemble de données ressemble à ceci. Comme d'habitude, j'étais à mi-chemin, mais je ne pense pas qu'il y ait beaucoup d'articles sur le traitement des ensembles de données, j'espère donc que vous le trouverez utile.
Recommended Posts