[PYTHON] How to train a structured SVM for ChainCRF with PyStruct

The documentation and examples on the official PyStruct site are not very beginner-friendly, so I tried it out with easy-to-understand data.

Preparation

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from pystruct.inference import inference_dispatch

The task is denoising a time series, as in Implementing HMM with PyStruct. For training, we use copies of a single fixed time series with noise added. (Set aside the fact that, since the true series is always the same, there is really nothing to infer.)

Creating the training data

n_samples = 500

d = np.array([12, 12, 11, 11, 10,  9,  8,  8,  7,  6,  6,  6,  7,  8,  8,  8,  6,
        5,  4,  3,  3,  3,  2,  1,  0,  1,  3,  4,  5,  6,  8,  8,  9,  9,
       10, 11, 12, 13, 14, 14, 14, 15, 15, 15, 15])
n_nodes = d.shape[0]
n_states = np.unique(d).shape[0]
n_features = n_states + 1 # add bias

y = np.repeat(d[np.newaxis,:], n_samples, axis=0)

data = y + (np.random.rand(n_samples, n_nodes)-0.5)*5

# negative sign for maximization !
X = np.array( [ [ [ -abs(i-j)**0.1 for j in range(n_states)]  for i in dd ] for dd in data] )

# add constant features for bias
X = np.array( [np.hstack((X[i], 0.1*np.ones((X[i].shape[0],1)))) for i in range(X.shape[0])] )
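The two nested list comprehensions above build, for each time step, the feature -|x_t - j|^0.1 for every state j, followed by a constant 0.1 bias column. As a sketch that is not part of the original notebook, the same array can be built with NumPy broadcasting, which avoids Python loops for larger data. `make_features` is a hypothetical helper name, and the random input only serves to check that both constructions agree.

```python
import numpy as np

def make_features(data, n_states, bias=0.1):
    """Broadcasting version of the nested comprehensions.

    Feature j at time t is -|data[t] - j| ** 0.1 (negated so that closer
    states score higher), plus one constant bias column at the end."""
    states = np.arange(n_states)
    # (n_samples, n_nodes, 1) - (n_states,) -> (n_samples, n_nodes, n_states)
    unary = -np.abs(data[..., np.newaxis] - states) ** 0.1
    bias_col = np.full(data.shape + (1,), bias)
    return np.concatenate([unary, bias_col], axis=-1)

rng = np.random.default_rng(0)
data = rng.uniform(0, 15, size=(4, 45))  # 4 hypothetical noisy series of length 45

X_fast = make_features(data, n_states=16)
# Same construction as the original loops (features + bias appended per time step)
X_loop = np.array([[[-abs(i - j) ** 0.1 for j in range(16)] + [0.1] for i in dd] for dd in data])

print(X_fast.shape)                  # (4, 45, 17)
print(np.allclose(X_fast, X_loop))   # True
```

With the article's variables this would be `X = make_features(data, n_states)`.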

X has 500 samples, each a time series of length 45; there are 16 states (classes) and 17 features (one per state plus a bias feature for the SVM).

Check the shapes

X.shape, y.shape
((500, 45, 17), (500, 45))

Split into training and test sets as usual

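# sklearn.cross_validation was removed in scikit-learn 0.20; the function now lives in model_selection
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)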

Check the training data

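# train_test_split shuffles, so split the raw noisy series the same way;
# the same random_state and number of samples give the same permutation as for X and y
data_train, data_test = train_test_split(data, test_size=0.4, random_state=0)

fig, axes = plt.subplots(3, 3, figsize=(20, 6))
for c, ax in enumerate(axes.ravel()):
    ax.plot(data_train[c], label='data')
    ax.plot(y_train[c], label='true')
plt.legend(bbox_to_anchor=(1.1, 1.0), loc=2, borderaxespad=0.)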


As a sanity check, compare the training features X (the feature vector at each time step) with y (the true, fixed time series).


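plt.matshow(np.flipud(X_train[0,:,:-1].T)) # drop the bias column; flip so that state 15 is at the top

plt.plot(15-y_train[0]) # flip the true states the same way (state s is drawn at row 15 - s)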


Now prepare the learner. Following the ChainCRF example in the PyStruct documentation, we train with FrankWolfeSSVM.
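For reference, a structured SVM is usually written in the margin-rescaled form below. This is the textbook formulation, not something stated in the PyStruct example; as far as I understand, FrankWolfeSSVM (block-coordinate Frank-Wolfe) optimizes a rescaled but equivalent version of it through its dual, with Δ being the Hamming loss for ChainCRF. For a chain, the score w^T ψ splits into one unary term per time step and one pairwise term per transition. Here x_t is the 17-dimensional feature vector at time t, T = 45, u_s is the unary weight vector of state s, and P is the n_states × n_states pairwise matrix; these are the two pieces plotted at the end of this article.

$$
\min_{w}\; \frac{1}{2}\lVert w\rVert^{2} + C\sum_{i=1}^{n}\max_{y}\Big[\Delta(y_i, y) + w^{\top}\psi(x_i, y) - w^{\top}\psi(x_i, y_i)\Big]
$$

$$
w^{\top}\psi(x, y) = \sum_{t=1}^{T} u_{y_t}^{\top} x_t + \sum_{t=1}^{T-1} P_{y_t,\,y_{t+1}}, \qquad \hat{y} = \arg\max_{y}\; w^{\top}\psi(x, y)
$$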

Setting up the learner

from pystruct.models import ChainCRF
from pystruct.learners import FrankWolfeSSVM
model = ChainCRF()
ssvm = FrankWolfeSSVM(model=model, C=.1, max_iter=10)


%time ssvm.fit(X_train, y_train)
CPU times: user 1.25 s, sys: 17.4 ms, total: 1.27 s
Wall time: 1.3 s

FrankWolfeSSVM(C=0.1, batch_mode=False, check_dual_every=10,
        do_averaging=True, line_search=True, logger=None, max_iter=10,
        model=ChainCRF(n_states: 16, inference_method: max-product),
        n_jobs=1, random_state=None, sample_method='perm',
        show_loss_every=0, tol=0.001, verbose=0)

So how well does it score on the test set?

ssvm.score(X_test, y_test)
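What does this score measure? My understanding (an assumption, not stated in the original) is that PyStruct computes it as 1 minus the normalized loss, and with ChainCRF's default Hamming loss and uniform class weights that amounts to the fraction of correctly labeled time steps. A small stand-alone sketch with toy labels; `hamming_score` is a hypothetical helper, not a PyStruct function:

```python
import numpy as np

def hamming_score(y_true, y_pred):
    """Fraction of time steps labeled correctly, pooled over all sequences.

    Assumes all sequences have the same length (true here: 45 steps),
    so pooling equals averaging the per-sequence accuracies."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return float(np.mean(y_true == y_pred))

# Toy example: 2 sequences of 4 steps, one mislabeled step out of 8
y_true = np.array([[0, 1, 2, 2], [3, 3, 2, 1]])
y_pred = np.array([[0, 1, 1, 2], [3, 3, 2, 1]])
print(hamming_score(y_true, y_pred))  # 0.875
```

Applied as `hamming_score(y_test, np.array(ssvm.predict(X_test)))`, it should agree with `ssvm.score(X_test, y_test)` if the assumption above holds.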

Check the predictions on the test set

X_test_predict = np.array(ssvm.predict(X_test))

# raw noisy test series, split with the same permutation as X_test / y_test
_, data_test = train_test_split(data, test_size=0.4, random_state=0)

fig, axes = plt.subplots(3, 3, figsize=(20, 6))
for c, ax in enumerate(axes.ravel()):
    ax.plot(data_test[c], label='data')
    ax.plot(X_test_predict[c], label='predict')
    ax.plot(y_test[c], label='true')

plt.legend(bbox_to_anchor=(1.1, 1.0), loc=2, borderaxespad=0.)


Check the learned weights w

ssvm.w.shape # = n_features * n_states + n_states**2
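The flat weight vector holds the unary block first and the pairwise block second, which is why the plots below slice it at `n_features * n_states`. A minimal sketch of that layout with a stand-in vector (`np.arange` instead of the learned `ssvm.w`); the n_states**2 pairwise size assumes ChainCRF's default directed model, matching the comment above, and reading `pairwise[s, t]` as the score of moving from state s to state t is my understanding of PyStruct's convention rather than something the original states:

```python
import numpy as np

n_states, n_features = 16, 17
size = n_features * n_states + n_states ** 2
w = np.arange(size, dtype=float)  # stand-in for ssvm.w

# Row s: weights applied to the 17 features when the label is state s
unary = w[:n_features * n_states].reshape(n_states, n_features)
# [s, t]: score added when state s at time t is followed by state t at time t+1
pairwise = w[n_features * n_states:].reshape(n_states, n_states)

print(size, unary.shape, pairwise.shape)  # 528 (16, 17) (16, 16)
```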

Pairwise weights w

plt.matshow(ssvm.w[n_features * n_states:].reshape(n_states, n_states))
plt.title("Transition parameters of the chain CRF.")


Unary weights w

plt.matshow(ssvm.w[:n_features * n_states].reshape(n_states,n_features))
plt.title("Unary parameters of the chain CRF.")
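To see how these two weight blocks produce a prediction: for each sequence, the model picks the label sequence maximizing the sum of unary scores (features times the unary row of each chosen state) plus pairwise scores for each transition. PyStruct does this internally (the repr above shows max-product inference); the sketch below is an independent plain-NumPy Viterbi (max-sum) decoder on toy random weights, checked against brute force. `chain_map` and `path_score` are hypothetical names, and the transition convention is the same assumption as in the layout sketch.

```python
import itertools
import numpy as np

def chain_map(unary_scores, pairwise):
    """Viterbi decoding on a chain.

    unary_scores: (n_nodes, n_states); pairwise: (n_states, n_states) with
    [s, u] = score of state s followed by state u. Returns the labels maximizing
    sum_t unary_scores[t, y_t] + sum_t pairwise[y_t, y_{t+1}]."""
    n_nodes, n_states = unary_scores.shape
    best = unary_scores[0].copy()  # best score of a path ending in each state at t = 0
    backptr = np.zeros((n_nodes, n_states), dtype=int)
    for t in range(1, n_nodes):
        cand = best[:, None] + pairwise       # cand[s, u]: end in s at t-1, move to u at t
        backptr[t] = cand.argmax(axis=0)      # best previous state for each u
        best = cand.max(axis=0) + unary_scores[t]
    labels = np.empty(n_nodes, dtype=int)
    labels[-1] = best.argmax()
    for t in range(n_nodes - 1, 0, -1):       # follow back-pointers
        labels[t - 1] = backptr[t, labels[t]]
    return labels

def path_score(labels, unary_scores, pairwise):
    """Total score of one label sequence."""
    return (unary_scores[np.arange(len(labels)), labels].sum()
            + pairwise[labels[:-1], labels[1:]].sum())

rng = np.random.default_rng(1)
n_states, n_features, n_nodes = 3, 4, 5
x = rng.normal(size=(n_nodes, n_features))         # features of one toy sequence
unary_w = rng.normal(size=(n_states, n_features))  # plays the role of the reshaped unary weights
pairwise = rng.normal(size=(n_states, n_states))   # plays the role of the reshaped pairwise weights
unary_scores = x @ unary_w.T                        # (n_nodes, n_states)

labels = chain_map(unary_scores, pairwise)
brute = max(itertools.product(range(n_states), repeat=n_nodes),
            key=lambda y: path_score(np.array(y), unary_scores, pairwise))
print(labels, np.array(brute))  # the two sequences should match
```

With the article's variables, `chain_map(X_test[0] @ unary.T, pairwise)` (using the reshaped slices of `ssvm.w`) should reproduce `X_test_predict[0]` if my reading of the weight layout is right.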

