Wie der Titel schon sagt, habe ich Single Shot Multibox Detector (SSD) mit PyTorch
([https://github.com] implementiert). / jjjkkkjjj / pytorch_SSD](https://github.com/jjjkkkjjj/pytorch_SSD)) Die Berechnung wird jedoch im Vergleich zu ssd.pytorch usw. berechnet. Es ist langsam (ich werde die Ursache später untersuchen). </ strike> [^ 1] Ich habe jedoch hart an der Abstraktion gearbeitet, daher denke ich, dass sie sehr anpassbar ist. (Ich weiß nicht, ob ich dokumentieren kann, wie man es richtig benutzt ...). Wenn Sie die Implementierung von SSD überprüfen, werden Sie viel finden, aber diesmal
[^ 1]: Ich habe das Initialisierungsargument num_workers
von DataLoader
einfach nicht gesetzt. .. .. Derzeit ist es so schnell wie ssd.pytorch.
――Ich wollte es frei anpassen
Ich habe es aus diesem Grund implementiert. In vielen Artikeln gibt es viele leicht verständliche Erklärungen (Referenz), aber ich möchte sie auf meine eigene Weise zusammenfassen, um meinen Geist zu organisieren.
PyTorch Zunächst werde ich kurz auf das diesmal verwendete Deep-Learning-Framework "PyTorch" eingehen. Als ich zum ersten Mal versuchte, SSD zu implementieren, habe ich Tensorflow verwendet. Jedoch,
compatible.v1
, compatible.v2
usw.?)Daher verlief die Implementierung nicht einfach. Insbesondere "schwer zu debuggen" war für mich fatal und ich konnte die Übersetzung nicht gut verstehen. Nun, ich denke, es wäre nützlich zu lernen, wie man es benutzt, nur weil ich "Tensorflow" nicht verstanden habe. .. ..
Ich dachte, dass es für den Rest meines Lebens nicht fertig sein würde, also änderte ich es in "Pytorch", das die Eigenschaft hat, Operationen ähnlich wie "Numpy" ausführen zu können. Als ich zu "PyTorch" wechselte, wurde die "Schwierigkeit des Debuggens", die ich in Tensorflow empfand, erheblich verbessert, und die Implementierung verlief reibungslos. Nun, die Operation von Numpy
ist Matft (https://github.com/jjjkkkjjj/Matft), [Ich habe eine N-dimensionale Matrixberechnungsbibliothek Matft mit Swift erstellt] Vertraut mit der Implementierung von (https://qiita.com/jjjkkkjjj/items/1f2b5c3835b1600d3129)) und MILES (https://github.com/jjjkkkjjj/MIL Weil es so war, war "PyTorch" perfekt für mich.
Zunächst möchte ich kurz auf die SSD eingehen und sie dann im Detail zusammenfassen. SSD ist ein Objekterkennungsalgorithmus, der die Position und Beschriftung eines Objekts von Ende zu Ende vorhersagen kann. Es ist eine geeignete Abbildung, aber wenn Sie ein Eingabebild wie dieses angeben, gibt die SSD die Position und Beschriftung des Objekts sofort aus.
Was dieses Modell tut und tut, ist wie folgt. Ich werde Schritt für Schritt erklären.
Im Original-SSD-Dokument lauten die Datensätze PASCAL VOC2007, [PASCAL VOC2012](http: //host.robots.ox). .ac.uk / pascal / VOC / voc2012 /) und COCO2014 werden verwendet. COCO ist noch nicht implementiert, daher werde ich den VOC-Datensatz erläutern. </ strike> Lassen Sie uns zunächst über den VOC-Datensatz sprechen.
Die Struktur des Verzeichnisses ist im Grunde genommen einheitlich und sieht wie folgt aus.
Voc-Verzeichnis
$ tree -I '*.png|*.xml|*.jpg|*.txt'
└── VOCdevkit
└── VOC20**
├── Annotations
├── ImageSets
│ ├── Action
│ ├── Layout
│ ├── Main
│ └── Segmentation
├── JPEGImages
├── SegmentationClass
└── SegmentationObject
Was Sie für die Objekterkennung benötigen, sind "Annotations", "JPEGImages" und "ImageSets / Main" direkt unter "VOC20 **". Jeder ist wie folgt.
.jpeg
-Datei von JPEGImages
..txt
-Datei, die Informationen zum Datensatz enthält. Die Dateinamen von "Annotation" und "JPEGImages", die die Elemente des Sets sind, werden beschrieben..xml
-Datei)Die .xml
-Datei im Annotaions
-Verzeichnis lautet wie folgt.
Annotations/~~.xml
<annotation>
<folder>VOC2007</folder>
<filename>000005.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
<flickrid>325991873</flickrid>
</source>
<owner>
<flickrid>archintent louisville</flickrid>
<name>?</name>
</owner>
<size>
<width>500</width>
<height>375</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>chair</name>
<pose>Rear</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>263</xmin>
<ymin>211</ymin>
<xmax>324</xmax>
<ymax>339</ymax>
</bndbox>
</object>
<object>
<name>chair</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>165</xmin>
<ymin>264</ymin>
<xmax>253</xmax>
<ymax>372</ymax>
</bndbox>
</object>
...
</annotation>
Die wichtigen Punkte sind wie folgt.
<filename>
<object>
<name>
--Markenname<truncated>
0
) oder teilweise sichtbar ( 1
) ist.
<difficult>
1
) oder nicht ( 0
)
<bndbox>
Um ein Dataset zu implementieren, müssen Sie die Klasse "Dataset" erweitern. Dann ist es notwendig, "len" zu implementieren, das die Anzahl der Datensätze zurückgibt, und "getitem", das die Eingabedaten und die korrekte Antwortbezeichnung für "index" im Bereich der Anzahl der Datensätze zurückgibt.
Was wir mit der folgenden Implementierung machen
.xml
-Dateien direkt unter Annotations
in self._annopaths
als Listeist.
PyTorch
vertriebene VGG Pre-Trained-Modell verwendet wird. Das vorab trainierte Modell von "PyTorch" trainiert das durch die RGB-Reihenfolge normalisierte Bild "Mittelwert = (0,485, 0,456, 0,406)", "Standard = (0,229, 0,224, 0,225)" als Eingabe. (Referenz)ObjectDetectionDatasetBase
class ObjectDetectionDatasetBase(_DatasetBase):
def __init__(self, ignore=None, transform=None, target_transform=None, augmentation=None):
Kürzung
def __getitem__(self, index):
"""
:param index: int
:return:
img : rgb image(Tensor or ndarray)
targets : Tensor or ndarray of bboxes and labels [box, label]
= [xmin, ymin, xmamx, ymax, label index(or relu_one-hotted label)]
or
= [cx, cy, w, h, label index(or relu_one-hotted label)]
"""
img = self._get_image(index)
bboxes, linds, flags = self._get_bbox_lind(index)
img, bboxes, linds, flags = self.apply_transform(img, bboxes, linds, flags)
# concatenate bboxes and linds
if isinstance(bboxes, torch.Tensor) and isinstance(linds, torch.Tensor):
if linds.ndim == 1:
linds = linds.unsqueeze(1)
targets = torch.cat((bboxes, linds), dim=1)
else:
if linds.ndim == 1:
linds = linds[:, np.newaxis]
targets = np.concatenate((bboxes, linds), axis=1)
return img, targets
def apply_transform(self, img, bboxes, linds, flags):
"""
IMPORTATANT: apply transform function in order with ignore, augmentation, transform and target_transform
:param img:
:param bboxes:
:param linds:
:param flags:
:return:
Transformed img, bboxes, linds, flags
"""
# To Percent mode
height, width, channel = img.shape
# bbox = [xmin, ymin, xmax, ymax]
# [bbox[0] / width, bbox[1] / height, bbox[2] / width, bbox[3] / height]
bboxes[:, 0::2] /= float(width)
bboxes[:, 1::2] /= float(height)
if self.ignore:
bboxes, linds, flags = self.ignore(bboxes, linds, flags)
if self.augmentation:
img, bboxes, linds, flags = self.augmentation(img, bboxes, linds, flags)
if self.transform:
img, bboxes, linds, flags = self.transform(img, bboxes, linds, flags)
if self.target_transform:
bboxes, linds, flags = self.target_transform(bboxes, linds, flags)
return img, bboxes, linds, flags
VOCDatasetBase
VOC_class_labels = ['aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor']
VOC_class_nums = len(VOC_class_labels)
class VOCSingleDatasetBase(ObjectDetectionDatasetBase):
def __init__(self, voc_dir, focus, ignore=None, transform=None, target_transform=None, augmentation=None, class_labels=None):
"""
:param voc_dir: str, voc directory path above 'Annotations', 'ImageSets' and 'JPEGImages'
e.g.) voc_dir = '~~~~/trainval/VOCdevkit/voc2007'
:param focus: str, image set name. Assign txt file name under 'ImageSets' directory
:param ignore: target_transforms.Ignore
:param transform: instance of transforms
:param target_transform: instance of target_transforms
:param augmentation: instance of augmentations
:param class_labels: None or list or tuple, if it's None use VOC_class_labels
"""
super().__init__(ignore=ignore, transform=transform, target_transform=target_transform, augmentation=augmentation)
self._voc_dir = voc_dir
self._focus = focus
self._class_labels = _check_ins('class_labels', class_labels, (list, tuple), allow_none=True)
if self._class_labels is None:
self._class_labels = VOC_class_labels
layouttxt_path = os.path.join(self._voc_dir, 'ImageSets', 'Main', self._focus + '.txt')
if os.path.exists(layouttxt_path):
with open(layouttxt_path, 'r') as f:
filenames = f.read().splitlines()
filenames = [filename.split()[0] for filename in filenames]
self._annopaths = [os.path.join(self._voc_dir, 'Annotations', '{}.xml'.format(filename)) for filename in filenames]
else:
raise FileNotFoundError('layout: {} was invalid arguments'.format(focus))
@property
def class_nums(self):
return len(self._class_labels)
@property
def class_labels(self):
return self._class_labels
def _jpgpath(self, filename):
"""
:param filename: path containing .jpg
:return: path of jpg
"""
return os.path.join(self._voc_dir, 'JPEGImages', filename)
def __len__(self):
return len(self._annopaths)
"""
Detail of contents in voc > https://towardsdatascience.com/coco-data-format-for-object-detection-a4c5eaf518c5
VOC bounding box (xmin, ymin, xmax, ymax)
"""
def _get_image(self, index):
"""
:param index: int
:return:
rgb image(ndarray)
"""
root = ET.parse(self._annopaths[index]).getroot()
img = cv2.imread(self._jpgpath(_get_xml_et_value(root, 'filename')))
# pytorch's image order is rgb
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img.astype(np.float32)
def _get_bbox_lind(self, index):
"""
:param index: int
:return:
list of bboxes, list of bboxes' label index, list of flags([difficult, truncated,...])
"""
linds = []
bboxes = []
flags = []
root = ET.parse(self._annopaths[index]).getroot()
for obj in root.iter('object'):
linds.append(self._class_labels.index(_get_xml_et_value(obj, 'name')))
bndbox = obj.find('bndbox')
# bbox = [xmin, ymin, xmax, ymax]
bboxes.append([_get_xml_et_value(bndbox, 'xmin', int), _get_xml_et_value(bndbox, 'ymin', int), _get_xml_et_value(bndbox, 'xmax', int), _get_xml_et_value(bndbox, 'ymax', int)])
flags.append({'difficult': _get_xml_et_value(obj, 'difficult', int) == 1})#,
#'partial': _get_xml_et_value(obj, 'truncated', int) == 1})
return np.array(bboxes, dtype=np.float32), np.array(linds, dtype=np.float32), flags
Die Struktur des Verzeichnisses ist insofern dieselbe wie bei VOC, als es in Anmerkungen und Bilder ("Anmerkungen" und "Bilder / {train oder val} 20 **") unterteilt ist, die Behandlung von Anmerkungen unterscheidet sich jedoch geringfügig.
├── annotations
│ ├── captions_train2014.json
│ ├── captions_val2014.json
│ ├── instances_train2014.json
│ ├── instances_val2014.json
│ ├── person_keypoints_train2014.json
│ └── person_keypoints_val2014.json
└── images
├── train2014
└── val2014
Wie Sie sehen, werden im Gegensatz zu VOC alle Anmerkungen in einer Datei geschrieben.
Und was Sie für die Objekterkennung benötigen, ist die Datei instance_ {train or val} 20 **. Json
.
Das Format ist ausführlich in Official beschrieben. Und da COCO über python api verfügt, lautet die Anmerkungsdatei für die Objekterkennung "instance_ {train or val} 20 **. Json". Wenn Sie ehrlich verstehen, müssen Sie den Inhalt nicht sehr gut verstehen.
Nur für den Fall, wenn ich das Format überprüfe, sieht es so aus.
{
"info": info,
"images": [image],
"annotations": [annotation],
"licenses": [license],
}
info{
"year": int,
"version": str,
"description": str,
"contributor": str,
"url": str,
"date_created": datetime,
}
image{
"id": int,
"width": int,
"height": int,
"file_name": str,
"license": int,
"flickr_url": str,
"coco_url": str,
"date_captured": datetime,
}
license{
"id": int,
"name": str,
"url": str,
}
Die "Annotation" und "Kategorie" der Objekterkennung sind wie folgt.
annotation{
"id": int,
"image_id": int,
"category_id": int,
"segmentation": RLE or [polygon],
"area": float, "bbox": [x,y,width,height],
"iscrowd": 0 or 1,
}
categories[{
"id": int,
"name": str,
"supercategory": str,
}]
Implementieren Sie es einfach wie VOC. Die notwendigen Informationen werden über das COCO-Objekt der API erfasst.
COCO_class_labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
COCO_class_nums = len(COCO_class_labels)
COCO2014_ROOT = os.path.join(DATA_ROOT, 'coco', 'coco2014')
class COCOSingleDatasetBase(ObjectDetectionDatasetBase):
def __init__(self, coco_dir, focus, ignore=None, transform=None, target_transform=None, augmentation=None, class_labels=None):
"""
:param coco_dir: str, coco directory path above 'annotations' and 'images'
e.g.) coco_dir = '~~~~/coco2007/trainval'
:param focus: str or str, directory name under images
e.g.) focus = 'train2014'
:param ignore: target_transforms.Ignore
:param transform: instance of transforms
:param target_transform: instance of target_transforms
:param augmentation: instance of augmentations
:param class_labels: None or list or tuple, if it's None use VOC_class_labels
"""
super().__init__(ignore=ignore, transform=transform, target_transform=target_transform, augmentation=augmentation)
self._coco_dir = coco_dir
self._focus = focus
self._class_labels = _check_ins('class_labels', class_labels, (list, tuple), allow_none=True)
if self._class_labels is None:
self._class_labels = COCO_class_labels
self._annopath = os.path.join(self._coco_dir, 'annotations', 'instances_' + self._focus + '.json')
if os.path.exists(self._annopath):
self._coco = COCO(self._annopath)
else:
raise FileNotFoundError('json: {} was not found'.format('instances_' + self._focus + '.json'))
# remove no annotation image
self._imageids = list(self._coco.imgToAnns.keys())
@property
def class_nums(self):
return len(self._class_labels)
@property
def class_labels(self):
return self._class_labels
def _jpgpath(self, filename):
"""
:param filename: path containing .jpg
:return: path of jpg
"""
return os.path.join(self._coco_dir, 'images', self._focus, filename)
def __len__(self):
return len(self._imageids)
"""
Detail of contents in voc > https://towardsdatascience.com/coco-data-format-for-object-detection-a4c5eaf518c5
VOC bounding box (xmin, ymin, xmax, ymax)
"""
def _get_image(self, index):
"""
:param index: int
:return:
rgb image(ndarray)
"""
"""
self._coco.loadImgs(self._imageids[index]): list of dict, contains;
license: int
file_name: str
coco_url: str
height: int
width: int
date_captured: str
flickr_url: str
id: int
"""
filename = self._coco.loadImgs(self._imageids[index])[0]['file_name']
img = cv2.imread(self._jpgpath(filename))
# pytorch's image order is rgb
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img.astype(np.float32)
def _get_bbox_lind(self, index):
"""
:param index: int
:return:
list of bboxes, list of bboxes' label index, list of flags([difficult, truncated,...])
"""
linds = []
bboxes = []
flags = []
# anno_ids is list
anno_ids = self._coco.getAnnIds(self._imageids[index])
# annos is list of dict
annos = self._coco.loadAnns(anno_ids)
for anno in annos:
"""
anno's keys are;
segmentation: list of float
area: float
iscrowd: int, 0 or 1
image_id: int
bbox: list of float, whose length is 4
category_id: int
id: int
"""
"""
self._coco.loadCats(anno['category_id']) is list of dict, contains;
supercategory: str
id: int
name: str
"""
cat = self._coco.loadCats(anno['category_id'])[0]
linds.append(self.class_labels.index(cat['name']))
# bbox = [xmin, ymin, w, h]
xmin, ymin, w, h = anno['bbox']
# convert to corners
xmax, ymax = xmin + w, ymin + h
bboxes.append([xmin, ymin, xmax, ymax])
"""
flag = {}
keys = ['iscrowd']
for key in keys:
if key in anno.keys():
flag[key] = anno[key] == 1
else:
flag[key] = False
flags.append(flag)
"""
flags.append({'difficult': anno['iscrowd'] == 1})
return np.array(bboxes, dtype=np.float32), np.array(linds, dtype=np.float32), flags
Augmentation Eine Erweiterung ist nicht immer erforderlich, sondern auch in Originalartikel.
Data augmentation is crucial
Es scheint wichtig zu sein, weil es als erwähnt wird. In der Originalarbeit wird die spezifische Methode weggelassen, es gibt jedoch zwei Haupttypen von Augmentationsmethoden.
Im Folgenden werde ich darüber schreiben, wie dieses Originalbild erweitert wird.
Geometric Distortions In Geometrische Verzerrungen gibt es die folgenden drei Methoden.
Random Expand
Wie der Name schon sagt, wird die Größe zufällig erweitert. --Füllen Sie die Ränder für die Größenerweiterung mit dem Durchschnittswert "rgb_mean = (103.939, 116.779, 123.68)", der bei der Normalisierung verwendet wird.
Random Sample
Stichprobe nach dem Zufallsprinzip.
Zu diesem Zeitpunkt wird der Schwellenwert für den Grad der Überlappung (IoU-Wert) zwischen dem abgetasteten Bild und dem Begrenzungsrahmen zufällig bestimmt.
Einer von "(0,1,0,3,0,5,0,7,0,9, keine)"
--None
hat keine Schwelle. "IoU = 0" ohne Überlappung ist jedoch ausgeschlossen.
Wiederholen Sie diesen Vorgang, bis das Beispielbild den Schwellenwert überschreitet.
Ich werde die Implementierung für einen Moment weglassen. Insbesondere ist es hier. Weitere Bildbeispiele finden Sie unter hier.
Photometric Distortions Bei photometrischen Verzerrungen gibt es die folgenden fünf Methoden.
Ich werde die Implementierung für einen Moment weglassen. Insbesondere ist es hier. Weitere Bildbeispiele finden Sie unter hier.
Transform
Dies ist die Vorverarbeitung des Eingabebildes.
Die Verarbeitung in ↑ wird wie folgt implementiert. PyTorch
hat eine Verarbeitungsfunktion transformiert für das vorverarbeitete Eingabebild, aber das PIL
-Bild Da es eine Funktion für ist, habe ich es für Opencv gemacht. Wenn Sie Ihre eigenen Transformationen erstellen, müssen Sie der Verarbeitung in der Klassenmethode des obigen Datensatzes entsprechen. Dieses Mal übergibt die Methode "_apply_transform" die Flag-Informationen wie "img", "bboxes", "linds", "flags", dh Bild, Begrenzungsrahmen, Beschriftung, schwierig usw. als Argumente, wie unten gezeigt. (* Übrigens habe ich es weggelassen, aber "Augmentation" hat die gleiche Implementierungsmethode.)
_apply_Transformationsmethode
if self.ignore:
bboxes, linds, flags = self.ignore(bboxes, linds, flags)
if self.augmentation:
img, bboxes, linds, flags = self.augmentation(img, bboxes, linds, flags)
if self.transform:
img, bboxes, linds, flags = self.transform(img, bboxes, linds, flags)
if self.target_transform:
bboxes, linds, flags = self.target_transform(bboxes, linds, flags)
Daher sollten Sie die Methode __call__ (self, img, bboxes, linds, flags)
implementieren.
Die "ToTensor" -Klasse, die in "torch.Tensor" konvertiert wird, hat die Reihenfolge des Eingabebildes von "Conv2d" von "PyTorch" aus der Reihenfolge "(h, w, c)" von OpenCV "(b, c, h, w). Umgerechnet in `.
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img, bboxes, labels, flags):
for t in self.transforms:
img, bboxes, labels, flags = t(img, bboxes, labels, flags)
return img, bboxes, labels, flags
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class ToTensor(object):
"""
Note that convert ndarray to tensor and [0-255] to [0-1]
"""
def __call__(self, img, *args):
# convert ndarray into Tensor
# transpose img's tensor (h, w, c) to pytorch's format (c, h, w). (num, c, h, w)
img = np.transpose(img, (2, 0, 1))
return (torch.from_numpy(img).float() / 255., *args)
class Resize(object):
def __init__(self, size):
"""
:param size: 2d-array-like, (height, width)
"""
self._size = size
def __call__(self, img, *args):
return (cv2.resize(img, self._size), *args)
class Normalize(object):
#def __init__(self, rgb_means=(103.939, 116.779, 123.68), rgb_stds=(1.0, 1.0, 1.0)):
def __init__(self, rgb_means=(0.485, 0.456, 0.406), rgb_stds=(0.229, 0.224, 0.225)):
self.means = np.array(rgb_means, dtype=np.float32).reshape((-1, 1, 1))
if np.any(np.abs(self.means) > 1):
logging.warning("In general, mean value should be less than 1 because img's range is [0-1]")
self.stds = np.array(rgb_stds, dtype=np.float32).reshape((-1, 1, 1))
def __call__(self, img, *args):
if isinstance(img, torch.Tensor):
return ((img.float() - torch.from_numpy(self.means)) / torch.from_numpy(self.stds), *args)
else:
return ((img.astype(np.float32) - self.means) / self.stds, *args)
Anwendungsbeispiel
from data import transforms
transform = transforms.Compose(
[transforms.Resize((300, 300)),
transforms.ToTensor(),
transforms.Normalize(rgb_means=(0.485, 0.456, 0.406), rgb_stds=(0.229, 0.224, 0.225))]
)
Target transform
Konvertierung von Begrenzungsrahmen und Etiketten.
Dort sind drei.
Darstellung der Zentroide
--Verwenden Sie die Mittelkoordinate $ (c_x, c_y) $ und die Breite / Höhe $ (w, h)
Corners Ausdruck
--Verwenden Sie die obere linke Koordinate $ (x_ {min}, y_ {min}) $ und die untere rechte Koordinate $ (x_ {max}, y_ {max})
--minmax Ausdruck --Verwenden Sie die obere linke Koordinate $ (x_ {min}, y_ {min}) $ und die untere rechte Koordinate $ (x_ {max}, y_ {max}) $.
--Zentralkoordinate $ (c_x, c_y) $ und Breite / Höhe $ (w, h) $, obere linke Koordinate $ (x_ {min}, y_ {min}) $ und untere rechte Koordinate $ (x_ {max}) , y_ {max}) $ Beziehung
\begin{align}
(c_x,c_y) &= (\frac{x_{min}+x_{max}}{2},\frac{y_{min}+y_{max}}{2}) \\
(w,h) &= (x_{max}-x_{min},y_{max}-y_{min})
\end{align}
Der Prozess von ↑ wird wie folgt implementiert. Die korrekte Etikettenverarbeitung für die Objekterkennung target_transforms
ist in PyTorch
nicht vorhanden, daher müssen Sie target_transforms selbst erstellen. Wiederum übergibt die Methode "_apply_transform" die Flag-Informationen wie "bboxes", "linds", "flags", dh den Begrenzungsrahmen, die Bezeichnung und die schwierigen Argumente, wie unten gezeigt.
_apply_Transformationsmethode
if self.ignore:
bboxes, linds, flags = self.ignore(bboxes, linds, flags)
if self.augmentation:
img, bboxes, linds, flags = self.augmentation(img, bboxes, linds, flags)
if self.transform:
img, bboxes, linds, flags = self.transform(img, bboxes, linds, flags)
if self.target_transform:
bboxes, linds, flags = self.target_transform(bboxes, linds, flags)
Daher können Sie die Methode __call__ (self, bboxes, linds, flags)
implementieren.
class ToTensor(object):
def __call__(self, bboxes, labels, flags):
return torch.from_numpy(bboxes), torch.from_numpy(labels), flags
class ToCentroids(object):
def __call__(self, bboxes, labels, flags):
# bbox = [xmin, ymin, xmax, ymax]
bboxes = np.concatenate(((bboxes[:, 2:] + bboxes[:, :2]) / 2,
(bboxes[:, 2:] - bboxes[:, :2])), axis=1)
return bboxes, labels, flags
class ToCorners(object):
def __call__(self, bboxes, labels, flags):
# bbox = [cx, cy, w, h]
bboxes = np.concatenate((bboxes[:, :2] - bboxes[:, 2:]/2,
bboxes[:, :2] + bboxes[:, 2:]/2), axis=1)
return bboxes, labels, flags
class OneHot(object):
def __init__(self, class_nums, add_background=True):
self._class_nums = class_nums
self._add_background = add_background
if add_background:
self._class_nums += 1
def __call__(self, bboxes, labels, flags):
if labels.ndim != 1:
raise ValueError('labels might have been already relu_one-hotted or be invalid shape')
labels = _one_hot_encode(labels.astype(np.int), self._class_nums)
labels = np.array(labels, dtype=np.float32)
return bboxes, labels, flags
Anwendungsbeispiel
target_transform = target_transforms.Compose(
[target_transforms.ToCentroids(),
target_transforms.OneHot(class_nums=datasets.VOC_class_nums, add_background=True),
target_transforms.ToTensor()]
)
Die Datensatzverarbeitung sieht folgendermaßen aus. Wie üblich war ich auf halbem Weg, aber ich glaube nicht, dass es viele Artikel zur Datensatzverarbeitung gibt. Ich hoffe, Sie finden es hilfreich.
Recommended Posts