[Summary] "End-to-End Object Detection with Transformers" (DETR): an object detection model using the Transformer


I was curious about "End-to-End Object Detection with Transformers" (DETR), so I read the paper and briefly tried the model out. This is a short record of what I found. [Paper, GitHub]

What is DETR (summary)

・ A model released by Facebook AI Research in May 2020

・ Applies the Transformer, well known from natural language processing, to object detection for the first time

・ Simple network architecture: a CNN backbone followed by a Transformer

・ Achieves "End-to-End" object detection by eliminating components that require manual tuning, such as NMS and the default anchor box settings

・ The paper highlights "Bipartite Matching Loss" and "Parallel Decoding" as the key points that make this possible

・ Applicable not only to object detection but also to segmentation tasks
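
As a rough illustration of the bipartite matching idea: during training, each ground-truth object is assigned to exactly one prediction by solving an assignment problem, so duplicate detections are penalized and NMS becomes unnecessary. The sketch below uses scipy.optimize.linear_sum_assignment for the assignment. The cost here (negative class probability plus a weighted L1 box distance) is a simplified assumption; the paper's matching cost also includes a generalized IoU term. The function name match_predictions, the box_weight value, and the toy numbers are made up for illustration.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def match_predictions(pred_probs, pred_boxes, gt_labels, gt_boxes, box_weight=5.0):
    """Assign each ground-truth object to one prediction (simplified cost).

    pred_probs: (N, num_classes) class probabilities per prediction
    pred_boxes: (N, 4) boxes as (cx, cy, w, h) in [0, 1]
    gt_labels:  (M,) integer class ids
    gt_boxes:   (M, 4) boxes in the same format
    """
    # Classification cost: higher probability of the true class -> lower cost.
    cost_class = -pred_probs[:, gt_labels]                                    # (N, M)
    # Box cost: L1 distance between every prediction and every ground truth.
    cost_box = np.abs(pred_boxes[:, None, :] - gt_boxes[None, :, :]).sum(-1)  # (N, M)
    cost = cost_class + box_weight * cost_box
    # Minimum-cost one-to-one assignment (Hungarian-style solver).
    pred_idx, gt_idx = linear_sum_assignment(cost)
    return [(int(p), int(g)) for p, g in zip(pred_idx, gt_idx)]


# Toy example: 3 predictions, 2 ground-truth objects (values are invented).
pred_probs = np.array([[0.1, 0.9], [0.8, 0.2], [0.5, 0.5]])
pred_boxes = np.array([[0.5, 0.5, 0.2, 0.2],
                       [0.2, 0.2, 0.1, 0.1],
                       [0.8, 0.8, 0.3, 0.3]])
gt_labels = np.array([0, 1])
gt_boxes = np.array([[0.2, 0.2, 0.1, 0.1],
                     [0.5, 0.5, 0.2, 0.2]])
matches = match_predictions(pred_probs, pred_boxes, gt_labels, gt_boxes)
print(matches)  # [(0, 1), (1, 0)]
```

In this toy case prediction 2 is left unmatched; in DETR such unmatched predictions are trained toward the extra "no object" class.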


Inference code example

The following code is quoted from the paper. Everything from model definition to inference fits in about 40 lines.

import torch
from torch import nn
from torchvision.models import resnet50

class DETR(nn.Module):
    def __init__(self, num_classes, hidden_dim, nheads,
                 num_encoder_layers, num_decoder_layers):
        # nn.Module must be initialized before submodules are assigned
        super().__init__()
        # We take only convolutional layers from ResNet-50 model
        self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
        self.conv = nn.Conv2d(2048, hidden_dim, 1)
        self.transformer = nn.Transformer(hidden_dim, nheads,
                                          num_encoder_layers, num_decoder_layers)
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        x = self.backbone(inputs)
        h = self.conv(x)
        H, W = h.shape[-2:]
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)
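        # Encoder input: image features plus positional encodings;
        # decoder input: the learned object queries
        h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),
                             self.query_pos.unsqueeze(1))
        return self.linear_class(h), self.linear_bbox(h).sigmoid()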

detr = DETR(num_classes=91, hidden_dim=256, nheads=8,
            num_encoder_layers=6, num_decoder_layers=6)
detr.eval()  # inference mode: fixes BatchNorm statistics and disables dropout
inputs = torch.randn(1, 3, 800, 1200)
logits, bboxes = detr(inputs)
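
To turn raw outputs like these into detections, something along the following lines can be used. It is a minimal sketch that assumes the output layout of the code above: logits of shape (num_queries, 1, num_classes + 1) with the last class meaning "no object", and boxes as normalized (cx, cy, w, h). The function name postprocess, the 0.7 threshold, and the hand-made tensors are illustrative assumptions, not part of the paper's listing.

```python
import torch


def postprocess(logits, bboxes, image_size, threshold=0.7):
    """Convert raw DETR-style outputs for one image into pixel-space boxes.

    logits: (num_queries, 1, num_classes + 1), last class = "no object"
    bboxes: (num_queries, 1, 4), normalized (cx, cy, w, h)
    image_size: (width, height) of the original image in pixels
    """
    # Drop the batch dimension and the trailing "no object" class.
    probs = logits[:, 0].softmax(-1)[:, :-1]
    scores, labels = probs.max(-1)
    keep = scores > threshold

    # (cx, cy, w, h) -> (x_min, y_min, x_max, y_max), then scale to pixels.
    cx, cy, w, h = bboxes[:, 0].unbind(-1)
    boxes = torch.stack([cx - 0.5 * w, cy - 0.5 * h,
                         cx + 0.5 * w, cy + 0.5 * h], dim=-1)
    img_w, img_h = image_size
    boxes = boxes * torch.tensor([img_w, img_h, img_w, img_h], dtype=boxes.dtype)
    return scores[keep], labels[keep], boxes[keep]


# Hand-made tensors standing in for model outputs: 3 queries, 2 classes + "no object".
example_logits = torch.tensor([[[8.0, 0.0, 0.0]],    # confident class 0
                               [[0.0, 0.0, 8.0]],    # confident "no object"
                               [[0.0, 0.5, 0.0]]])   # low confidence, filtered out
example_bboxes = torch.tensor([[[0.5, 0.5, 0.2, 0.4]],
                               [[0.1, 0.1, 0.1, 0.1]],
                               [[0.9, 0.9, 0.1, 0.1]]])
scores, labels, boxes = postprocess(example_logits, example_bboxes, image_size=(1200, 800))
print(labels.tolist(), boxes.tolist())
```

Only the first query survives the threshold here, with a box of roughly (480, 240, 720, 560) in pixels.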

Operation check

I actually checked the operation using the trained model. The environment is as follows.

・ OS: Ubuntu 18.04.4 LTS

・ GPU: GeForce RTX 2060 SUPER (8GB) x1

・ PyTorch 1.5.1 / torchvision 0.6.0

I based the detection code on the official detr_demo.ipynb and ran detection on frames captured from a webcam with OpenCV. The model used is the ResNet-50 based DETR.
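
A rough sketch of how a webcam frame might be prepared for the model: OpenCV returns frames as BGR uint8 arrays in height x width x channel order, while the pretrained model expects a normalized RGB tensor in channel-first order. The mean and standard deviation are the usual ImageNet statistics, which detr_demo.ipynb also uses. The function name preprocess_frame and the dummy frame are assumptions for illustration, and the resizing step that the notebook performs with torchvision transforms is omitted here.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess_frame(frame_bgr):
    """Turn an OpenCV-style BGR uint8 frame (H, W, 3) into a (1, 3, H, W) float array."""
    rgb = frame_bgr[:, :, ::-1].astype(np.float32) / 255.0   # BGR -> RGB, scale to [0, 1]
    normalized = (rgb - IMAGENET_MEAN) / IMAGENET_STD         # per-channel normalization
    chw = normalized.transpose(2, 0, 1)                       # HWC -> CHW
    return np.ascontiguousarray(chw[None])                    # add batch dimension


# A dummy all-blue frame stands in for `ret, frame = cv2.VideoCapture(0).read()`.
frame = np.zeros((480, 640, 3), dtype=np.uint8)
frame[:, :, 0] = 255          # channel 0 is blue in BGR order
batch = preprocess_frame(frame)
print(batch.shape)            # (1, 3, 480, 640)
```

The resulting array can be converted with torch.from_numpy(batch) before being passed to the model.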

The following is an actual detection result. I confirmed that detection ran normally. (Note that the scene also contains objects that are not among the COCO classes.) [Image: frame_290.jpg]

In my environment, the inference process itself seemed to run in about 45 msec (about 22 FPS).
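
For reference, a latency figure like this could be measured roughly as follows. This is a generic sketch, not the script actually used for the number above: measure_latency_ms, latency_to_fps, and the dummy workload are hypothetical. When timing a model on a GPU, torch.cuda.synchronize() should be called before reading the clock, because CUDA kernels run asynchronously.

```python
import time


def measure_latency_ms(fn, n_warmup=3, n_runs=20):
    """Return the mean wall-clock time of fn() in milliseconds."""
    for _ in range(n_warmup):       # warm-up runs are not timed
        fn()
    start = time.perf_counter()
    for _ in range(n_runs):
        fn()
    elapsed = time.perf_counter() - start
    return elapsed / n_runs * 1000.0


def latency_to_fps(latency_ms):
    """Convert per-frame latency in milliseconds to frames per second."""
    return 1000.0 / latency_ms


# Dummy workload standing in for `detr(inputs)`.
latency = measure_latency_ms(lambda: sum(range(10000)))
print(f"{latency:.3f} ms per call")
print(f"45 ms -> {latency_to_fps(45.0):.1f} FPS")   # 45 ms -> 22.2 FPS
```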


I studied DETR and confirmed that it works. Since it is a new kind of approach, I expect it to improve further in both accuracy and speed. I had thought of the Transformer as specific to natural language processing, but it has recently been making its way into models that handle images. I would like to keep studying this area. (I also want to try running Image GPT.)


