[PYTHON] I get the error "... is zero, singular U" when passing a covariance matrix built from Linear layer output to MultivariateNormal

What I did

I am using PyTorch v1.5.1. Take the output of a Linear layer as (x, y, vxx, vyy, vxy), build a variance-covariance matrix from it, and pass that matrix to MultivariateNormal.

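import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal

# n is the input feature size; x is assumed here to be a 1-D input tensor of length n
fc = nn.Linear(n, 5)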
output = fc(x)
mean = output[:2]
vxx, vyy = nn.Softplus()(output[2:4])
vxy = output[-1]

covariance_matrix = torch.zeros(2, 2)
covariance_matrix[0, 0] += vxx
covariance_matrix[0, 1] += vxy
covariance_matrix[1, 0] += vxy
covariance_matrix[1, 1] += vyy

dist = MultivariateNormal(mean, covariance_matrix)

RuntimeError: cholesky_cuda: For batch 0: U(6,6) is zero, singular U.

MultivariateNormal computes a Cholesky decomposition of the covariance matrix, so the matrix must be positive definite. The code above raises this error because nothing guarantees that covariance_matrix is positive definite: softplus makes vxx and vyy positive, but vxy is unconstrained, so $ vxx \cdot vyy - vxy^{2} $ can be zero or negative.
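The condition can be checked by hand for the 2x2 case. The following is a minimal sketch in plain Python (no PyTorch), assuming made-up raw layer outputs; the helper names softplus and is_positive_definite_2x2 are hypothetical and only illustrate Sylvester's criterion for a symmetric 2x2 matrix.

```python
import math


def softplus(z: float) -> float:
    """softplus(z) = log(1 + exp(z)); positive for any real z."""
    return math.log1p(math.exp(z))


def is_positive_definite_2x2(vxx: float, vyy: float, vxy: float) -> bool:
    """Sylvester's criterion for the symmetric matrix [[vxx, vxy], [vxy, vyy]]."""
    return vxx > 0 and vxx * vyy - vxy * vxy > 0


# Hypothetical raw outputs of the Linear layer: (x, y, vxx, vyy, vxy)
raw = [0.3, -0.1, 0.5, 0.2, 2.0]
vxx, vyy = softplus(raw[2]), softplus(raw[3])
vxy = raw[4]  # unconstrained off-diagonal entry

# Both diagonal entries are positive, yet the matrix is not positive definite
diagonal_positive = vxx > 0 and vyy > 0
positive_definite = is_positive_definite_2x2(vxx, vyy, vxy)
print(diagonal_positive, positive_definite)
```

With these values the diagonal is positive but the determinant is negative, which is the situation in which the Cholesky decomposition fails.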

Solution: instead of the covariance matrix, pass its lower-triangular Cholesky factor.

fc = nn.Linear(n, 5)
output = fc(x)
mean = output[:2]
a, c = nn.Softplus()(output[2:4])
b = output[-1]

L = torch.zeros(2, 2)
L[0, 0] += a
L[1, 0] += b
L[1, 1] += c

dist = MultivariateNormal(mean, scale_tril=L)

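The MultivariateNormal documentation describes this argument as follows:

scale_tril (Tensor) – lower-triangular factor of covariance, with positive-valued diagonal

Therefore, the diagonal components $a$ and $c$ are made positive with softplus, while the off-diagonal component $b$ is left unconstrained.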
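To see why a positive diagonal is enough, here is a minimal sketch in plain Python (no PyTorch) that builds the factor from made-up raw values and checks the product of L with its transpose. The helpers build_scale_tril and covariance_from_tril are hypothetical names used only for illustration.

```python
import math


def softplus(z: float) -> float:
    """softplus(z) = log(1 + exp(z)); positive for any real z."""
    return math.log1p(math.exp(z))


def build_scale_tril(raw_a: float, raw_b: float, raw_c: float):
    """Lower-triangular 2x2 factor with a softplus-positive diagonal."""
    a, c = softplus(raw_a), softplus(raw_c)
    return [[a, 0.0], [raw_b, c]]


def covariance_from_tril(L):
    """Sigma = L L^T for a 2x2 list-of-lists matrix."""
    return [[sum(L[i][k] * L[j][k] for k in range(2)) for j in range(2)]
            for i in range(2)]


# Hypothetical raw values; b is deliberately large and unconstrained
L = build_scale_tril(raw_a=0.5, raw_b=2.0, raw_c=0.2)
sigma = covariance_from_tril(L)

# Determinant of sigma equals (a * c)^2, so it is positive whenever a, c > 0
det = sigma[0][0] * sigma[1][1] - sigma[0][1] * sigma[1][0]
print(sigma[0][0] > 0 and det > 0)
```

The same value of b that can break a directly assembled covariance matrix is harmless here, because it only enters the product through the lower-triangular factor.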

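If you need the covariance matrix itself, reconstruct it from the Cholesky factor. Since

\Sigma = LL^{T}

it is sufficient to compute

covariance_matrix = torch.matmul(L, L.T)

Here torch.matmul is used rather than np.dot because L is a PyTorch tensor that takes part in the autograd graph.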
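For the 2x2 case the product can be written out explicitly. The derivation below uses the entries a, b, c of L from the code above and shows why the result is positive definite for any b, as long as a and c are positive.

```latex
\Sigma = LL^{T}
= \begin{pmatrix} a & 0 \\ b & c \end{pmatrix}
  \begin{pmatrix} a & b \\ 0 & c \end{pmatrix}
= \begin{pmatrix} a^{2} & ab \\ ab & b^{2} + c^{2} \end{pmatrix},
\qquad
\det \Sigma = a^{2}(b^{2} + c^{2}) - (ab)^{2} = a^{2}c^{2} > 0
```

Both leading principal minors, $a^{2}$ and $a^{2}c^{2}$, are positive, so $\Sigma$ is positive definite by Sylvester's criterion.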
