[PYTHON] Wie man strukturiertes SVM von ChainCRF mit PyStruct lernt

Die Dokumente und Beispiele im Original-Web sind unfreundlich, daher habe ich versucht, leicht verständliche Daten zu verwenden.

Erste Vorbereitung


import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from pystruct.inference import inference_dispatch

Der Inhalt ist die Rauschentfernung von Zeitreihendaten wie in Implementieren von HMM mit PyStruct. Zum Lernen wird eine Zeitreihe mit Rauschen verwendet, das zu einer festen Zeitreihe hinzugefügt wird. (Abgesehen von der Tatsache, dass es behoben ist, so dass Sie nicht schließen müssen)

Erstellung von Trainingsdaten


n_samples = 500

d = np.array([12, 12, 11, 11, 10,  9,  8,  8,  7,  6,  6,  6,  7,  8,  8,  8,  6,
        5,  4,  3,  3,  3,  2,  1,  0,  1,  3,  4,  5,  6,  8,  8,  9,  9,
       10, 11, 12, 13, 14, 14, 14, 15, 15, 15, 15])
n_nodes = d.shape[0]
n_states = np.unique(d).shape[0]
n_features = n_states + 1 # add bias

y = np.repeat(d[np.newaxis,:], n_samples, axis=0)

data = y + (np.random.rand(n_samples, n_nodes)-0.5)*5

# negative sign for maximization !
X = np.array( [ [ [ -abs(i-j)**0.1 for j in range(n_states)]  for i in dd ] for dd in data] )

# add constant features for bias
X = np.array( [np.hstack((X[i], 0.1*np.ones((X[i].shape[0],1)))) for i in range(X.shape[0])] )

Daten X verfügt über 500 Zahlen, 45 Zeitreihenlängen, 16 Zustände / Klassen und 17 Merkmale (SVM-Bias).

Größe prüfen


X.shape, y.shape
===
((500, 45, 17), (500, 45))

Teilen Sie das Lernen und Testen wie gewohnt


from sklearn.cross_validation import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)

Überprüfen Sie die Trainingsdaten


fig, axes = plt.subplots(3,3, figsize=(20,6))
c=0
for ax in axes.ravel():
    ax.plot(data[c], label='data')
    ax.plot(y_train[c], label='true')
    ax.set_xticks(())
    ax.set_yticks(())
    c += 1
plt.legend(bbox_to_anchor=(1.1, 1.0), loc=2, borderaxespad=0.)

Unknown1.png

Vergleich der Trainingsdaten X (Merkmale zu jedem Zeitpunkt) und y (echte feste Zeitreihen) zur Bestätigung.

Bestätigung


plt.matshow(np.flipud(X_train[0,:,:-1].T)) # remove bias
plt.colorbar()
plt.yticks(())
#plt.show()

plt.plot(15-y_train[0]) # flipud
plt.show()

Unknown2.png

Bereiten Sie nun den Lernenden vor. Lernen Sie mit FrancWolfe SSVM gemäß der Erklärung von ChainCRF von PyStruct.

Vorbereitung des Lernenden


from pystruct.models import ChainCRF
from pystruct.learners import FrankWolfeSSVM
model = ChainCRF()
ssvm = FrankWolfeSSVM(model=model, C=.1, max_iter=10)

Lernen!


%%time
ssvm.fit(X_train, y_train)
====
CPU times: user 1.25 s, sys: 17.4 ms, total: 1.27 s
Wall time: 1.3 s

FrankWolfeSSVM(C=0.1, batch_mode=False, check_dual_every=10,
        do_averaging=True, line_search=True, logger=None, max_iter=10,
        model=ChainCRF(n_states: 16, inference_method: max-product),
        n_jobs=1, random_state=None, sample_method='perm',
        show_loss_every=0, tol=0.001, verbose=0)

Was ist dann die vorhergesagte Punktzahl?


ssvm.score(X_test, y_test)
==========
0.56377777777777771

Überprüfen Sie die Vorhersagen für den Test


X_test_predict = np.array(ssvm.predict(X_test))

fig, axes = plt.subplots(3,3, figsize=(20,6))
shf = np.arange(X_test.shape[0])
np.random.shuffle(shf)
c=0
for ax in axes.ravel():
    ax.plot(data[shf[c]], label='data')
    ax.plot(X_test_predict[shf[c]], label='predict')
    ax.plot(y_test[shf[c]], label='true')
    ax.set_xticks(())
    ax.set_yticks(())
    c += 1

plt.legend(bbox_to_anchor=(1.1, 1.0), loc=2, borderaxespad=0.)

Unknown3.png

Überprüfen Sie das gelernte w


ssvm.w.shape # = n_features * n_states + n_states**2
========
(528,)

Paarweises Gewicht w


plt.matshow(ssvm.w[n_features * n_states:].reshape(n_states, n_states))
plt.title("Transition parameters of the chain CRF.")
plt.xticks(np.arange(n_states))
plt.yticks(np.arange(n_states))
plt.colorbar()
plt.show()

Unknown4.png

unäres Gewicht w


plt.matshow(ssvm.w[:n_features * n_states].reshape(n_states,n_features))
plt.title("Unary parameters of the chain CRF.")
plt.yticks(np.arange(n_states))
plt.xticks(np.arange(n_features))
plt.ylabel('states') 
plt.xlabel('features')
plt.colorbar()
plt.show()

Unknown5.png

Recommended Posts

Wie man strukturiertes SVM von ChainCRF mit PyStruct lernt
Ableiten der MAP-Schätzung von HMM mit PyStruct
[Hugo] Zusammenfassung zum Hinzufügen von Seiten zu der mit Learn erstellten Site
So legen Sie Attribute mit Mock of Python fest
So implementieren Sie "named_scope" von RubyOnRails mit Django
Wie man Kaldi mit JUST Corpus trainiert
Ableiten der MAP-Schätzung von HMM mit OpenGM
[How to!] Lerne und spiele Super Mario mit Tensorflow !!
Zusammenfassung, wie der Status mit mehreren Funktionen geteilt wird
Wie aktualisiere ich mit SQLAlchemy?
Wie mit SQLAlchemy ändern?
So trennen Sie Zeichenfolgen mit ','
Wie lösche ich mit SQLAlchemy?
So aktivieren Sie das Lesen / Schreiben von net.Conn mit golang, um mit dem Kontext abzubrechen
So brechen Sie RT mit tweepy ab
So extrahieren Sie Funktionen von Zeitreihendaten mit PySpark Basics
Python: So verwenden Sie Async mit
Zusammenfassung der Verwendung von pandas.DataFrame.loc
So erhalten Sie die ID von Type2Tag NXP NTAG213 mit nfcpy
So verwenden Sie virtualenv mit PowerShell
Zusammenfassung der Verwendung von pyenv-virtualenv
Ich habe versucht, das Entwicklungsstartverfahren von Django kurz zusammenzufassen
Wie fange ich mit Scrapy an?
Umgang mit dem DistributionNotFound-Fehler
Wie fange ich mit Django an?
So überwachen Sie den Ausführungsstatus von sqlldr mit dem Befehl pv
Aufblasen von Daten (Datenerweiterung) mit PyTorch
Erklärt, wie TensorFlow 2.X mit der Implementierung von VGG16 / ResNet50 verwendet wird
Node.js: So töten Sie Nachkommen eines Prozesses, der von child_process.fork () gestartet wurde
So berechnen Sie das Datum mit Python
Zusammenfassung der Verwendung von csvkit
So verbinden Sie INNER mit SQL Alchemy
So installieren Sie Anaconda mit pyenv
[EC2] So machen Sie mit Selen eine Bildschirmaufnahme Ihres Smartphones
So schneiden Sie den unteren rechten Teil des Bildes mit Python OpenCV
[Einführung in Python] So sortieren Sie den Inhalt einer Liste effizient mit Listensortierung
[Bilderkennung] Lesen des Ergebnisses der automatischen Annotation mit VoTT
Wie man mit verstümmelten Charakteren in json von Django REST Framework umgeht
Zusammenfassung zum Erstellen einer LAMP + Wordpress-Umgebung mit Sakura VPS
So führen Sie eine arithmetische Verarbeitung mit der Django-Vorlage durch
[Blender] So legen Sie shape_key mit dem Skript fest
[Python] Zusammenfassung der Verwendung von Pandas
Wie man mit matplotlib mehrere Figuren betitelt
Wie man die schöne Suppeninstanziierung beschleunigt
So erhalten Sie die Eltern-ID mit sqlalchemy
Lernen Sie mit Chainer, monochrome Bilder einzufärben
Wie man lange Einschlüsse loswird
So richten Sie SVM mit Optuna ein
So installieren Sie DLIB mit aktiviertem 2020 / CUDA
Verwendung von ManyToManyField mit Djangos Admin
Verwendung von OpenVPN mit Ubuntu 18.04.3 LTS
Verwendung von Cmder mit PyCharm (Windows)
So verhindern Sie Paketaktualisierungen mit apt
So arbeiten Sie mit BigQuery in Python
Wie man Ass / Alembic mit HtoA benutzt
Umgang mit Enum-Kompatibilitätsfehlern