[PYTHON] Convert RGB and HSV in a differentiable form with PyTorch

Somewhat surprisingly, I couldn't find an existing implementation, so I wrote one based on the formulas in the Wikipedia article on the HSV color space.

If you only need to convert between RGB and HSV, other libraries (PIL, OpenCV, etc.) will do. Those conversions are not differentiable, though, so gradients cannot flow through them and they cannot be placed inside a neural network. I therefore wrote functions that convert between RGB and HSV using only PyTorch functions.
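To make the difference concrete, here is a minimal sketch (not part of color_convert.py) that uses only the V channel, which is the per-pixel maximum of R, G, and B. The pixel values are made up for illustration. Computed with PyTorch ops, the result stays attached to the autograd graph. Once the values leave PyTorch, as they would for a PIL- or OpenCV-based conversion, that connection is lost.

```python
import torch

# One pixel in NC layout (batch of 1, 3 channels); values are arbitrary.
rgb = torch.tensor([[0.2, 0.7, 0.4]], requires_grad=True)

# V computed with PyTorch ops: autograd records the operation.
v_torch = rgb.max(dim=1).values
v_torch.sum().backward()
grad = rgb.grad.clone()  # 1 for the max channel (G), 0 for the others

# V computed outside PyTorch (plain Python floats here; NumPy arrays behave
# the same way): the new tensor has no link back to rgb.
v_plain = torch.tensor([max(rgb.detach().tolist()[0])])

print(grad)                    # tensor([[0., 1., 0.]])
print(v_torch.requires_grad)   # True
print(v_plain.requires_grad)   # False
```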

Code

color_convert.py


import torch


def rgb2hsv(input, epsilon=1e-10):
    assert(input.shape[1] == 3)

    r, g, b = input[:, 0], input[:, 1], input[:, 2]
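    max_rgb, _ = input.max(1)
    # The index of the smallest channel selects the hue formula below.
    min_rgb, argmin_rgb = input.min(1)

    # epsilon avoids division by zero for gray pixels (max == min).
    max_min = max_rgb - min_rgb + epsilon

    h1 = 60.0 * (g - r) / max_min + 60.0   # used where B is the minimum
    h2 = 60.0 * (b - g) / max_min + 180.0  # used where R is the minimum
    h3 = 60.0 * (r - b) / max_min + 300.0  # used where G is the minimum

    # Stack in argmin order (R, G, B) and pick one hue per pixel.
    h = torch.stack((h2, h3, h1), dim=0).gather(dim=0, index=argmin_rgb.unsqueeze(0)).squeeze(0)
    # Use the true max - min here so that black maps to s = 0 rather than 1.
    s = (max_rgb - min_rgb) / (max_rgb + epsilon)
    v = max_rgb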

    return torch.stack((h, s, v), dim=1)


def hsv2rgb(input):
    assert(input.shape[1] == 3)

    h, s, v = input[:, 0], input[:, 1], input[:, 2]
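    # Wrap hue into [0, 360), then scale to sextant units in [0, 6).
    h_ = (h - torch.floor(h / 360) * 360) / 60
    c = s * v  # chroma
    x = c * (1 - torch.abs(torch.fmod(h_, 2) - 1))

    zero = torch.zeros_like(c)
    # One (R, G, B) candidate per hue sextant; shape (6, N, 3, H, W).
    y = torch.stack((
        torch.stack((c, x, zero), dim=1),
        torch.stack((x, c, zero), dim=1),
        torch.stack((zero, c, x), dim=1),
        torch.stack((zero, x, c), dim=1),
        torch.stack((x, zero, c), dim=1),
        torch.stack((c, zero, x), dim=1),
    ), dim=0)

    # Sextant index per pixel, repeated over the 3 channels.
    index = torch.repeat_interleave(torch.floor(h_).unsqueeze(1), 3, dim=1).unsqueeze(0).to(torch.long)
    # Guard against rounding that can push h_ to exactly 6 for hues just below a multiple of 360.
    index = index.clamp(0, 5)
    # (v - c) needs a channel axis so that it broadcasts over (N, 3, H, W) for any batch size.
    rgb = y.gather(dim=0, index=index).squeeze(0) + (v - c).unsqueeze(1)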
    return rgb
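The hue rule in rgb2hsv (choose the formula by which channel is the smallest) can be checked by hand on single pixels. The following is a minimal scalar sketch, not part of color_convert.py: hue_by_argmin is a hypothetical name, the three test colors are arbitrary, and the standard library's colorsys serves as the reference.

```python
import colorsys


def hue_by_argmin(r, g, b, epsilon=1e-10):
    """Scalar version of the hue rule: pick the formula by the smallest channel."""
    mx, mn = max(r, g, b), min(r, g, b)
    d = mx - mn + epsilon
    if mn == r:
        return 60.0 * (b - g) / d + 180.0
    if mn == g:
        return 60.0 * (r - b) / d + 300.0
    return 60.0 * (g - r) / d + 60.0  # blue is the minimum


for rgb in [(1.0, 0.5, 0.0), (0.2, 0.7, 0.4), (0.3, 0.1, 0.9)]:
    h_ref = colorsys.rgb_to_hsv(*rgb)[0] * 360.0  # colorsys returns hue in [0, 1)
    print(rgb, round(hue_by_argmin(*rgb), 3), round(h_ref, 3))
```

For these three colors the two hue columns should agree (30, 144, and 255 degrees). When two channels tie for the minimum, as in pure red, the rule may return 360 instead of 0, which is the same hue after wrapping.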

How to use

Both rgb2hsv and hsv2rgb take a mini-batch of images in NCHW format as input. H (hue) ranges from 0 to 360, and values outside that range wrap around. RGB, S, and V range from 0 to 1.
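Images loaded with PIL or OpenCV usually arrive as H x W x C arrays of 8-bit values, so they need to be rearranged and rescaled before being passed to these functions. Below is a minimal sketch of that preparation, assuming an RGB channel order. to_nchw_unit is a hypothetical helper name, and a random tensor stands in for a real image. The last lines illustrate the hue wrap-around with the same expression that hsv2rgb uses.

```python
import torch


def to_nchw_unit(image_hwc_uint8):
    """Hypothetical helper: (H, W, 3) uint8 tensor -> (1, 3, H, W) float tensor in [0, 1]."""
    chw = image_hwc_uint8.permute(2, 0, 1)               # HWC -> CHW
    return chw.unsqueeze(0).to(torch.float32) / 255.0    # add batch axis, rescale


# Stand-in for a loaded image: random 8-bit pixels, 4 x 5 in size.
image = torch.randint(0, 256, (4, 5, 3), dtype=torch.uint8)
batch = to_nchw_unit(image)
print(batch.shape)  # torch.Size([1, 3, 4, 5])

# Hue values outside [0, 360) wrap around, so these three describe the same hue.
hue = torch.tensor([-90.0, 270.0, 630.0])
wrapped = hue - torch.floor(hue / 360) * 360
print(wrapped)  # tensor([270., 270., 270.])
```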
