[PYTHON] Décrire un réseau qui accepte les informations d'annotation des utilisateurs dans Keras

Lorsque j'ai reçu les informations d'annotation (pondération) de l'utilisateur, j'ai essayé de mettre en œuvre le modèle de réseau Deep Learning qui permute le traitement avec Keras, je vais donc résumer le contenu. Récemment, il a été principalement implémenté dans PyTorch, donc j'étais confus par la différence dans la méthode de description. Utilisez l'API fonctionnelle pour décrire des réseaux complexes dans Keras. Référence: page Qiita sur l'utilisation de l'API fonctionnelle keras

Dans l'API fonctionnelle, il est nécessaire de connecter les couches définies dans keras.layers. Il est nécessaire de mettre en œuvre en utilisant Lambda pour mettre une couche de traitement d'origine comme cette fois. Voici un exemple de code d'un réseau qui assume une tâche de reconnaissance d'image comme indiqué. SelfAttentionNetwork.png

Normalement, l'image d'origine est donnée comme données d'entrée sur le réseau. Cette fois, en plus de cela, une carte pondérée correspondant à l'image originale et un drapeau (0 ou 1) indiquant s'il faut utiliser la carte pondérée donnent un total de trois entrées. L'image d'origine, la carte pondérée et l'indicateur pour utiliser la carte pondérée sont définis dans keras.layers.Input car leur taille varie en fonction de l'ensemble de données d'entrée. Par conséquent, il n'est pas possible de juger simplement par l'instruction If et de changer le traitement.

Code du modèle de réseau

from keras.models import Model
from keras.layers import Conv2D, Activation, BatchNormalization, GlobalAveragePooling2D, Dense, Input, Lambda, Add, Multiply
from keras.backend import switch as k_switch
from keras.backend import equal as k_equal
import numpy as np

def net(x, user_weight_map, user_weight_map_flg, feature_ch=16):
 """
    x:Image originale
    user_weight_map:Carte pondérée donnée par l'utilisateur
    user_weight_map_flg:Drapeau pour utiliser la carte pondérée donnée par l'utilisateur
    """
    #Appliquer la convolution 4 fois
    h = Conv2d(feature_ch, 3, strides=2, padding='same')(x)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)

    h = Conv2d(feature_ch*2, 3, strides=2, padding='same')(x)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    
    h = Conv2d(feature_ch*4, 3, strides=2, padding='same')(x)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    
    h = Conv2d(feature_ch*8, 3, strides=2, padding='same')(x)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    
    #---------------------
    #Carte pondérée dans un réseau ramifié(Self Attention)Calculer
    bh = Conv2D(feature_ch*4, 3, strides=1, padding='same')(h)
    bh = BatchNormalization()(bh)
    bh = Activation(activation='relu')(bh)
    
    bh = Conv2D(feature_ch*2, 3, strides=1, padding='same')(bh)
    bh = BatchNormalization()(bh)
    bh = Activation(activation='relu')(bh)
    
    bh = Conv2D(2, 1, strides=1, padding='same')(bh)
    bh = BatchNormalization()(bh)
    bh = Activation(activation='relu')(bh)
    
    model_weight = Conv2D(1, 3, strides=1, padding='same')(bh)
    model_weight = BatchNormalization()(bh)
    model_weight = Activation(activation='sigmoid', name='model_weight_output')(bh)

    bh = Conv2D(2, 1, strides=1, padding='same')(bh)
    bh = GlobalAveragePooling2D()(ah)
    bh = Dense(1000)(bh)
    #---------------------

    #Lire les informations du drapeau et basculer entre l'utilisation de la carte pondérée calculée à partir du réseau ou l'utilisation de la carte pondérée créée par l'utilisateur
    weight_h = Lambda(lambda x: switch_weight_map(x), name='swith_weight_map')([h, model_weight, user_weight_map, user_weight_map_flg])

    h = Add(name='weight_map_add')([h, weight_h])

    h = Conv2d(feature_ch*16, 3, strides=2, padding='same')(h)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    
    h = Conv2d(feature_ch*32, 3, strides=2, padding='same')(h)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)

    h = GlobalAveragePooling2D()(h)
    h = Dense(1000)(h)
  
    return h, bh, model_weight

def switch_weight_map(inputs):
    feature_map = inputs[0]
    model_weight_map = inputs[1]
    user_weight_map = inputs[2]
    user_weight_map_flg = inputs[3]
    
    model_weight = Multiply()([feature_map, model_weight_map])
    user_weight = Multiply()([feature_map, user_weight_map])

    weight_cond = k_equal(user_weight_map_flg, 0)
    
    weight_h = k_switch(weight_cond, model_weight, user_weight)

    return weight_h

# Save Network Architecture
def save_network_param(save_path, feature_ch):
    param = {'base_feature_num':feature_ch}
    
    with open(save_path, 'w') as f:
        yaml.dump(param, f, default_flow_style=False)

# Load Network Architecture
def load_network_param(load_path):
    with open(load_path) as f:
        param = yaml.load(f)

    return param

Lors de la rotation du processus d'entraînement, si vous essayez d'enregistrer le modèle pour chaque époque avec l'argument save_weights_only = False dans la fonction de rappel keras.callbacks.ModelCheckpoint (), le message d'erreur est "Impossible de sélectionner les objets _thread.RLock". Quelque chose comme ça est sorti. De plus, lorsque j'ai essayé d'exporter le modèle avec model.to_json () ou model.to_yaml (), j'ai eu la même erreur. Il semblait que pickle ne pouvait pas être sérialisé en raison du fait qu'il y avait une entrée indéfinie jusqu'à ce que Lambda reçoive des données d'entrée. Dans keras.callbacks.ModelCheckpoint (), enregistrez l'argument avec save_weights_only = True. Préparez save_network_param () et load_network_param (), et pour utiliser le modèle créé par train dans predict, reproduisez la structure du réseau avec le code réseau et le fichier yaml exporté, et définissez le poids de chaque couche avec model.load_weights (). ..

Dans l'implémentation utilisant Lambda, l'astuce consistait à lister et à donner l'argument x comme [h, model_weight, user_weight_map, user_weight_map_flg]. Si vous ne prenez que user_weight_map_flg pour l'argument x de Lambda comme indiqué ci-dessous, Keras interprétera la structure du réseau et il ne sera pas possible de déterminer si model_weight sera connecté à d'autres couches lors de l'enregistrement ou du chargement. Impossible.

weight_h = Lambda(lambda x:k_switch(k_equal(x, 0), model_weight, user_weight), name='switch_weight_map')(user_weight_map_flg)

Informations de référence

https://stackoverflow.com/questions/52448652/attributeerror-nonetype-object-has-no-attribute-inbound-nodes-while-trying https://stackoverflow.com/questions/44855603/typeerror-cant-pickle-thread-lock-objects-in-seq2seq https://github.com/keras-team/keras/issues/8343 https://github.com/matterport/Mask_RCNN/issues/1126 https://stackoverflow.com/questions/53212672/read-only-mode-in-keras https://stackoverflow.com/questions/47066635/checkpointing-keras-model-typeerror-cant-pickle-thread-lock-objects/55229794#55229794 https://blog.shikoan.com/lambda_arguments/ https://github.com/keras-team/keras/issues/6621 https://stackoverflow.com/questions/59635570/keras-backend-k-switch-for-loss-function-error

Recommended Posts

Décrire un réseau qui accepte les informations d'annotation des utilisateurs dans Keras
En Python, j'ai créé un LINE Bot qui envoie des informations sur le pollen à partir des informations de localisation.
En Python, créez un décorateur qui accepte dynamiquement les arguments Créer un décorateur
Ecrire un réseau de co-auteurs dans un domaine spécifique en utilisant les informations d'arxiv