[PYTHON] Chanson auto-exploitée par apprentissage en profondeur (édition Stacked LSTM) [DW Day 6]

0. Grosso modo

1. Qu'est-ce que LSTM?

Voir ci-dessous.

2. Qu'est-ce que Stacked LSTM?

Un réseau neuronal avec plusieurs couches de LSTM. On s'attend à ce que des corrélations longues et des corrélations courtes puissent être apprises dans chaque couche en créant les couches. À propos, il existe également un réseau appelé Grid LSTM qui connecte les LSMT dans les directions verticale et horizontale pour le rendre multidimensionnel. Il semble que la tâche de prédiction des caractères de Wikipédia et la tâche de traduction chinoise fonctionnent bien.

3. Code de chaîne

J'ai créé un réseau neuronal comme indiqué dans la figure ci-dessous.

La couche d'entrée et la couche de sortie sont des vecteurs one-hot. Les quatre couches intermédiaires (orange) sont les couches LSTM.

class rnn(Chain):
  state = {}

  def __init__(self, n_vocab, n_units):
    print n_vocab, n_units
    super(rnn, self).__init__(
      l_embed = L.EmbedID(n_vocab, n_units),
      l1_x = L.Linear(n_units, 4 * n_units),
      l1_h = L.Linear(n_units, 4 * n_units),
      l2_x = L.Linear(n_units, 4 * n_units),
      l2_h = L.Linear(n_units, 4 * n_units),
      l3_x=L.Linear(n_units, 4 * n_units),
      l3_h=L.Linear(n_units, 4 * n_units),
      l4_x=L.Linear(n_units, 4 * n_units),
      l4_h=L.Linear(n_units, 4 * n_units),
      l_umembed = L.Linear(n_units, n_vocab)
    )

  def forward(self, x, t, train=True, dropout_ratio=0.5):
    h0 = self.l_embed(x)
    c1, h1 = F.lstm(
      self.state['c1'],
      F.dropout( self.l1_x(h0), ratio=dropout_ratio, train=train ) + self.l1_h(self.state['h1'])
    )
    c2, h2 = F.lstm(
      self.state['c2'],
      F.dropout( self.l2_x(h1), ratio=dropout_ratio, train=train ) + self.l2_h(self.state['h2'])
    )
    c3, h3 = F.lstm(
      self.state['c3'],
      F.dropout( self.l3_x(h2), ratio=dropout_ratio, train=train ) + self.l3_h(self.state['h3'])
    )
    c4, h4 = F.lstm(
      self.state['c4'],
      F.dropout( self.l4_x(h3), ratio=dropout_ratio, train=train ) + self.l4_h(self.state['h4'])
    )
    y = self.l_umembed(h4)
    self.state = {'c1': c1, 'h1': h1, 'c2': c2, 'h2': h2, 'c3': c3, 'h3': h3, 'c4': c4, 'h4': h4}
    if train:
      return F.softmax_cross_entropy(y, t), F.accuracy(y, t)
    else:
      return F.softmax(y), y.data

  def initialize_state(self, n_units, batchsize=1, train=True):
    for name in ('c1', 'h1', 'c2', 'h2', 'c3', 'h3', 'c4', 'h4'):
      self.state[name] = Variable(np.zeros((batchsize, n_units), dtype=np.float32), volatile=not train)

4. Expérience (j'ai essayé d'améliorer les performances de la génération automatique de musique)

Je souhaite améliorer les performances de la chanson auto-exploitée dans l'article précédent (RNN + LSTM [DW Day 1]). La dernière fois, il s'agissait de LSTM à 2 couches, mais je les ai remplacés par les LSTM à 4 couches mentionnés ci-dessus et j'ai essayé de réapprendre.

Données d'entraînement

J'ai utilisé les mêmes données midi que la dernière fois. Cependant, midi doit être converti au format de données texte. Le code à convertir en texte est également répertorié ci-dessous. Dans le code ci-dessous, vous ne pouvez extraire qu'une seule piste et la convertir en texte en utilisant python midi2text.py --midi foo.midi```. 0_80_40_00 0_90_4f_64 0_90_43_64 120_80_4f_00 ・ ・ ・ , 1 ou 2 octets suivant l'heure delta, octet d'état, octet d'état, concaténés avec un trait de soulignement (appelé "morceau") sont alignés séparés par des espaces d'une demi-largeur. Des données textuelles sont générées. Ces données textuelles ont été entraînées par le LSTM susmentionné en tant que données d'apprentissage.

Courbe d'apprentissage

Génération de chansons

Après l'apprentissage, une série de blocs a été générée en utilisant le même réseau. La séquence de blocs générée a été transformée en un fichier midi en utilisant le code décrit plus loin.

Cela s'est produit → Lecture (Attention! Le son sera joué immédiatement)

Impressions

La chanson était plutôt bonne. Cependant, il semble que l'influence du midi sous-jacent soit assez forte (restante). Je m'attendais aussi à ce que la chanson soit structurée avec Stacked LSTM, mais c'était plus monotone que je ne l'avais imaginé. Pour faire bouger la chanson, il est préférable d'ajouter intentionnellement du bruit irrégulier (son qui transfère la série qui s'est poursuivie jusque-là) au moment de la sortie de la série, au lieu de tout générer automatiquement. Je pense que vous pouvez l'obtenir. Du point de vue de la composition, il y a d'autres problèmes que la sélection sonore. L'ajustement des tonalités, l'ajout d'effets et les sessions avec plusieurs instruments sont également des problèmes.

code

--midi → text ( delta time_status byte_status byte suivi par les données réelles '' `format)

!/usr/bin/env python
 -*- coding: utf-8 -*-

import sys
import os
import struct
from binascii import *
from types import *
reload(sys)
sys.setdefaultencoding('utf-8')

def is_eq_0x2f(b):
  return int(b2a_hex(b), 16) == int('2f', 16)

def is_gte_0x80(b):
  return int(b2a_hex(b), 16) >= int('80', 16)

def is_eq_0xff(b):
  return int(b2a_hex(b), 16) == int('ff', 16)

def is_eq_0xf0(b):
  return int(b2a_hex(b), 16) == int('f0', 16)

def is_eq_0xf7(b):
  return int(b2a_hex(b), 16) == int('f7', 16)

def is_eq_0x8n(b):
  return int(b2a_hex(b), 16) >= int('80', 16) and int(b2a_hex(b), 16) <= int('8f', 16)

def is_eq_0x9n(b):
  return int(b2a_hex(b), 16) >= int('90', 16) and int(b2a_hex(b), 16) <= int('9f', 16)

def is_eq_0xan(b): # An: 3byte
  return int(b2a_hex(b), 16) >= int('a0', 16) and int(b2a_hex(b), 16) <= int('af', 16)

def is_eq_0xbn(b): # Bn: 3byte
  return int(b2a_hex(b), 16) >= int('b0', 16) and int(b2a_hex(b), 16) <= int('bf', 16)

def is_eq_0xcn(b): # Cn: 2byte
  return int(b2a_hex(b), 16) >= int('c0', 16) and int(b2a_hex(b), 16) <= int('cf', 16)

def is_eq_0xdn(b): # Dn: 2byte
  return int(b2a_hex(b), 16) >= int('d0', 16) and int(b2a_hex(b), 16) <= int('df', 16)

def is_eq_0xen(b): # En: 3byte
  return int(b2a_hex(b), 16) >= int('e0', 16) and int(b2a_hex(b), 16) <= int('ef', 16)

def is_eq_0xfn(b):
  return int(b2a_hex(b), 16) >= int('f0', 16) and int(b2a_hex(b), 16) <= int('ff', 16)

def mutable_lengths_to_int(bs):
  length = 0
  for i, b in enumerate(bs):
    if is_gte_0x80(b):
      length += ( int(b2a_hex(b), 16) - int('80', 16) ) * pow(int('80', 16), len(bs) - i - 1)
    else:
      length += int(b2a_hex(b), 16)
  return length

def int_to_mutable_lengths(length):
  length = int(length)
  bs = []
  append_flag = False
  for i in range(3, -1, -1):
    a = length / pow(int('80', 16), i)
    length -= a * pow(int('80', 16), i)
    if a > 0:
      append_flag = True
    if append_flag:
      if i > 0:
        bs.append(hex(a + int('80', 16))[2:].zfill(2))
      else:
        bs.append(hex(a)[2:].zfill(2))
  return bs if len(bs) > 0 else ['00']

def read_midi(path_to_midi):
  midi = open(path_to_midi, 'rb')
  data = {'header': [], 'tracks': []}
  track = {'header': [], 'chunks': []}
  chunk = {'delta': [], 'status': [], 'meta': [], 'length': [], 'body': []}
  current_status = None

  """
  Load data.header
  """
  bs = midi.read(14)
  data['header'] = [b for b in bs]

  while 1:
    """
    Load data.tracks[0].header
    """
    if len(track['header']) == 0:
      bs = midi.read(8)
      if bs == '':
        break
      track['header'] = [b for b in bs]

    """
    Load data.tracks[0].chunks[0]
    """
    # delta time
    # ----------
    b = midi.read(1)
    while 1:
      chunk['delta'].append(b)
      if is_gte_0x80(b):
        b = midi.read(1)
      else:
        break

    # status
    # ------
    b = midi.read(1)
    if is_gte_0x80(b):
      chunk['status'].append(b)
      current_status = b
    else:
      midi.seek(-1, os.SEEK_CUR)
      chunk['status'].append(current_status)

    # meta and length
    # ---------------
    if is_eq_0xff(current_status): # meta event
      b = midi.read(1)
      chunk['meta'].append(b)
      b = midi.read(1)
      while 1:
        chunk['length'].append(b)
        if is_gte_0x80(b):
          b = midi.read(1)
        else:
          break
      length = mutable_lengths_to_int(chunk['length'])
    elif is_eq_0xf0(current_status) or is_eq_0xf7(current_status): # sysex event
      b = midi.read(1)
      while 1:
        chunk['length'].append(b)
        if is_gte_0x80(b):
          b = midi.read(1)
        else:
          break
      length = mutable_lengths_to_int(chunk['length'])
    else: # midi event
      if is_eq_0xcn(current_status) or is_eq_0xdn(current_status):
        length = 1
      else:
        length = 2

    # body
    # ----
    for i in range(0, length):
      b = midi.read(1)
      chunk['body'].append(b)

    track['chunks'].append(chunk)


    if is_eq_0xff(chunk['status'][0]) and is_eq_0x2f(chunk['meta'][0]):
      data['tracks'].append(track)
      track = {'header': [], 'chunks': []}
    chunk = {'delta': [], 'status': [], 'meta': [], 'length': [], 'body': []}

  return data

def write_text(tracks):
  midi = open('out.txt', 'w')
  for track in tracks:
    for chunks in track:
      midi.write('{} '.format(chunks))

if __name__ == '__main__':
  from argparse import ArgumentParser
  parser = ArgumentParser(description='audio RNN')
  parser.add_argument('--midi', type=unicode, default='', help='path to the MIDI file')
  args = parser.parse_args()
  
  data = read_midi(args.midi)

  # extract midi track
 track_list = [1] # ← Numéro de piste que vous souhaitez extraire

  tracks = []
  for n in track_list:
    raw_data = []
    chunks = data['tracks'][n]['chunks']
    for i in range(0, len(chunks)):
      chunk = chunks[i]
      if is_eq_0xff(chunk['status'][0]) or \
         is_eq_0xf0(chunk['status'][0]) or \
         is_eq_0xf7(chunk['status'][0]) :
        continue
      raw_data.append('_'.join(
        [str(mutable_lengths_to_int(chunk['delta']))] +
        [str(b2a_hex(chunk['status'][0]))] +
        [str(b2a_hex(body)) for body in chunk['body']]
      ))
    tracks.append(raw_data)

  write_text(tracks)

--Text (delta time_status byte_status byte suivi par les données réelles`` format) → midi

!/usr/bin/env python
 -*- coding: utf-8 -*-

import sys
import os
import struct
from binascii import *
from types import *
reload(sys)
sys.setdefaultencoding('utf-8')

def int_to_mutable_lengths(length):
  length = int(length)
  bs = []
  append_flag = False
  for i in range(3, -1, -1):
    a = length / pow(int('80', 16), i)
    length -= a * pow(int('80', 16), i)
    if a > 0:
      append_flag = True
    if append_flag:
      if i > 0:
        bs.append(hex(a + int('80', 16))[2:].zfill(2))
      else:
        bs.append(hex(a)[2:].zfill(2))
  return bs if len(bs) > 0 else ['00']

def write_midi(tracks):
  print len(tracks)
  midi = open('out.midi', 'wb')

  """
  MIDI Header
  """
  header_bary = bytearray([])
  header_bary.extend([0x4d, 0x54, 0x68, 0x64, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00])
  header_bary.extend([int(hex(len(tracks))[2:].zfill(4)[i:i+2], 16) for i in range(0, 4, 2)])
  header_bary.extend([0x01, 0xe0])
  midi.write(header_bary)

  for track in tracks:
    track_bary = bytearray([])
    for chunk in track:
      # It is assumed that each chunk consists of just 4 elements
      if len(chunk.split('_')) != 4:
        continue
      int_delta, status, data1, data2 = chunk.split('_')

      if status[0] == '8' or status[0] == '9' or status[0] == 'a' or status[0] == 'b' or status[0] == 'e': # 3byte
        delta = int_to_mutable_lengths(int_delta)
        track_bary.extend([int(d, 16) for d in delta])
        track_bary.extend([int(status, 16)]) 
        track_bary.extend([int(data1, 16)])  
        track_bary.extend([int(data2, 16)])  
      elif status[0] == 'c' or status[0] == 'd':
        delta = int_to_mutable_lengths(int_delta)
        track_bary.extend([int(d, 16) for d in delta])
        track_bary.extend([int(status, 16)]) 
        track_bary.extend([int(data1, 16)])  
      else:
        print status[0]

    """
    Track header
    """
    header_bary = bytearray([])
    header_bary.extend([0x4d, 0x54, 0x72, 0x6b])
    header_bary.extend([int(hex(len(track_bary)+4)[2:].zfill(8)[i:i+2], 16) for i in range(0, 8, 2)])
    midi.write(header_bary)

    """
    Track body
    """
    print len(track_bary)
    midi.write(track_bary)

    """
    Track footer
    """
    footer_bary = bytearray([])
    footer_bary.extend([0x00, 0xff, 0x2f, 0x00])
    midi.write(footer_bary)

if __name__ == '__main__':

 # ↓ Arrange le format de "delta time_status byte_actual data following status byte" séparés par des espaces
 # Ne fonctionne pas bien si l'état d'exécution est inclus. .. ..
  txt = '0_80_40_00 0_90_4f_64 0_90_43_64 120_80_4f_00 0_80_43_00 0_90_51_64 0_90_45_64 480_80_51_00 0_80_45_00 0_90_4c_64 0_90_44_64 120_80_4c_00 0_80_44_00 0_90_4f_64 0_90_43_64 60_80_4f_00 0_80_43_00 0_90_4d_64 0_90_41_64 120_80_4d_00'
  tracks = [txt.split(' ')]
  write_midi(tracks)

Lien

Étudiez en profondeur l'apprentissage en profondeur [DW Day 0]

Recommended Posts

Chanson auto-exploitée par apprentissage en profondeur (édition Stacked LSTM) [DW Day 6]
Étudiez en profondeur le Deep Learning [DW Day 0]
Apprentissage profond / code de travail LSTM
<Cours> Apprentissage en profondeur: Jour 1 NN
Sujets> Deep Learning: Day3 RNN
[Apprentissage en profondeur] Classification d'images avec un réseau neuronal convolutif [DW jour 4]
Apprentissage profond appris par l'implémentation 1 (édition de retour)
Introduction au Deep Learning ~ Dropout Edition ~
Minutes d'étude: Jour 1
Premier jour d'étude de Python
Étudiez en profondeur le Deep Learning [DW Day 0]
[Rabbit Challenge (E qualification)] Apprentissage en profondeur (jour2)
Deep learning 2 appris par l'implémentation (classification d'images)
Apprentissage profond à partir de zéro (propagation vers l'avant)
[Rabbit Challenge (E qualification)] Apprentissage en profondeur (jour3)
<Cours> Deep Learning Day4 Renforcement de l'apprentissage / flux de tension
Produisez de belles vaches de mer par apprentissage profond
Détection d'objets par apprentissage profond pour comprendre en profondeur par Keras
[Rabbit Challenge (E qualification)] Deep learning (day4)
Fiche d'apprentissage (2ème jour) Scraping par #BeautifulSoup