[PYTHON] Describe a network that accepts annotation information from users in Keras

I went through some trial and error implementing a deep learning model in Keras that switches its processing depending on whether annotation (weighting) information is received from the user, so I am summarizing what I learned here. I have mostly been working in PyTorch recently, so the difference in how networks are described confused me. To describe a complex network in Keras, use the Functional API. Reference: Qiita page of keras functional API usage memo

With the Functional API, you build a network by connecting layers defined in keras.layers. To insert custom processing as a layer, as in this case, you need to wrap it in Lambda. Below is a code example of a network for an image recognition task, as shown in the figure. (Figure: SelfAttentionNetwork.png)

Normally, only the original image is given to the network as input. This time there are three inputs in total: the original image, a weighting map corresponding to it, and a flag (0 or 1) indicating whether to use that weighting map. The original image, the weighting map, and the flag are all defined with keras.layers.Input, because their sizes vary with the input dataset. They are therefore symbolic tensors when the model is built, so you cannot simply test the flag with a Python if statement to switch the processing.
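
Because the flag holds one value per sample, the switch has to be an element-wise selection rather than a Python branch. The following is a minimal NumPy sketch of that selection, not the Keras code itself. The function name switch_weight_map_np and the shapes in the docstring are assumptions made for illustration.

```python
import numpy as np

def switch_weight_map_np(feature_map, model_weight_map, user_weight_map, flag):
    """NumPy stand-in for the per-sample switch (hypothetical helper).

    Assumed shapes:
      feature_map:      (batch, H, W, C)
      model_weight_map: (batch, H, W, 1)
      user_weight_map:  (batch, H, W, 1)
      flag:             (batch, 1), 0 = model's map, 1 = user's map
    """
    # Weight the feature map with each candidate map (broadcast over channels).
    model_weighted = feature_map * model_weight_map
    user_weighted = feature_map * user_weight_map
    # Reshape the flag to (batch, 1, 1, 1) so it broadcasts over H, W and C.
    use_model = (flag == 0).reshape(-1, 1, 1, 1)
    # Select per sample instead of branching once for the whole batch.
    return np.where(use_model, model_weighted, user_weighted)

# Two samples: the first uses the model's map, the second the user's map.
feature_map = np.ones((2, 4, 4, 3), dtype=np.float32)
model_map = np.full((2, 4, 4, 1), 0.5, dtype=np.float32)
user_map = np.full((2, 4, 4, 1), 2.0, dtype=np.float32)
flag = np.array([[0], [1]], dtype=np.float32)

out = switch_weight_map_np(feature_map, model_map, user_map, flag)
```

A single batch can mix samples with and without a user-supplied map, which a single if statement over the whole batch could not express.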

Network model code

from keras.models import Model
from keras.layers import Conv2D, Activation, BatchNormalization, GlobalAveragePooling2D, Dense, Input, Lambda, Add, Multiply
from keras.backend import switch as k_switch
from keras.backend import equal as k_equal
import numpy as np
import yaml  # PyYAML; needed by save_network_param() and load_network_param()

def net(x, user_weight_map, user_weight_map_flg, feature_ch=16):
    """
    x:Original image
    user_weight_map:User-given weighting map
    user_weight_map_flg:Flag to use weighted map given by user
    """
    # Apply convolution 4 times, each stage taking the previous stage's output
    h = Conv2D(feature_ch, 3, strides=2, padding='same')(x)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)

    h = Conv2D(feature_ch*2, 3, strides=2, padding='same')(h)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    
    h = Conv2D(feature_ch*4, 3, strides=2, padding='same')(h)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    
    h = Conv2D(feature_ch*8, 3, strides=2, padding='same')(h)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    
    #---------------------
    # Compute the weighting map in a branch network (self-attention)
    bh = Conv2D(feature_ch*4, 3, strides=1, padding='same')(h)
    bh = BatchNormalization()(bh)
    bh = Activation(activation='relu')(bh)
    
    bh = Conv2D(feature_ch*2, 3, strides=1, padding='same')(bh)
    bh = BatchNormalization()(bh)
    bh = Activation(activation='relu')(bh)
    
    bh = Conv2D(2, 1, strides=1, padding='same')(bh)
    bh = BatchNormalization()(bh)
    bh = Activation(activation='relu')(bh)
    
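    # 1-channel weighting map in [0, 1], computed from the branch features
    model_weight = Conv2D(1, 3, strides=1, padding='same')(bh)
    model_weight = BatchNormalization()(model_weight)
    model_weight = Activation(activation='sigmoid', name='model_weight_output')(model_weight)

    # Branch output head
    bh = Conv2D(2, 1, strides=1, padding='same')(bh)
    bh = GlobalAveragePooling2D()(bh)
    bh = Dense(1000)(bh)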
    bh = Dense(1000)(bh)
    #---------------------

    # Read the flag and switch between the weighting map computed by the network and the one created by the user.
    weight_h = Lambda(lambda x: switch_weight_map(x), name='switch_weight_map')([h, model_weight, user_weight_map, user_weight_map_flg])

    h = Add(name='weight_map_add')([h, weight_h])

    h = Conv2D(feature_ch*16, 3, strides=2, padding='same')(h)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    
    h = Conv2D(feature_ch*32, 3, strides=2, padding='same')(h)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)

    h = GlobalAveragePooling2D()(h)
    h = Dense(1000)(h)
  
    return h, bh, model_weight

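def switch_weight_map(inputs):
    # inputs is the list passed to Lambda:
    # [feature map, model weighting map, user weighting map, flag]
    feature_map = inputs[0]
    model_weight_map = inputs[1]
    user_weight_map = inputs[2]
    user_weight_map_flg = inputs[3]
    
    # Weight the feature map with each candidate map
    model_weight = Multiply()([feature_map, model_weight_map])
    user_weight = Multiply()([feature_map, user_weight_map])

    # flag == 0: use the model's map; otherwise use the user's map
    weight_cond = k_equal(user_weight_map_flg, 0)
    
    weight_h = k_switch(weight_cond, model_weight, user_weight)

    return weight_h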

# Save Network Architecture
def save_network_param(save_path, feature_ch):
    param = {'base_feature_num':feature_ch}
    
    with open(save_path, 'w') as f:
        yaml.dump(param, f, default_flow_style=False)

# Load Network Architecture
def load_network_param(load_path):
    with open(load_path) as f:
        param = yaml.safe_load(f)  # yaml.load() without a Loader fails on recent PyYAML

    return param

During training, when I tried to save the model at every epoch using the keras.callbacks.ModelCheckpoint() callback with save_weights_only=False, I got an error along the lines of "can't pickle _thread.RLock objects". I got the same error when I tried to export the model with model.to_json() or model.to_yaml(). It seemed that the model could not be pickled because the Inputs that feed the Lambda remain undetermined until input data is given. The workaround is to pass save_weights_only=True to keras.callbacks.ModelCheckpoint(). Prepare save_network_param() and load_network_param(); then, to use the trained model for prediction, rebuild the network structure from the network code and the exported YAML file, and set each layer's weights with model.load_weights().
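
To show where a message like this comes from, here is a small standard-library sketch. It does not reproduce Keras internals; HoldsLock and try_pickle are hypothetical names, and the idea that the model ends up referencing a lock-holding object is an assumption. It only demonstrates that any object holding a threading.RLock cannot be pickled, while plain hyperparameter data can.

```python
import pickle
import threading

class HoldsLock:
    """Hypothetical stand-in for an object that keeps a lock internally."""
    def __init__(self):
        self.lock = threading.RLock()
        self.value = 42

def try_pickle(obj):
    """Return None on success, or the error message if pickling fails."""
    try:
        pickle.dumps(obj)
    except TypeError as err:
        return str(err)
    return None

# Pickling walks the instance attributes and fails on the RLock.
error_message = try_pickle(HoldsLock())

# Plain hyperparameters, like those written by save_network_param(), are fine.
plain_ok = try_pickle({"base_feature_num": 16})
```

This is why the workaround stores only plain values (the hyperparameters and the layer weights) and rebuilds the network from code instead of serializing the whole model object.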

The trick when implementing this with Lambda was to pass the argument x as a list, such as [h, model_weight, user_weight_map, user_weight_map_flg]. If you pass only user_weight_map_flg as Lambda's argument x, as shown below, Keras cannot determine, when it interprets the network structure on save or load, whether model_weight is connected to other layers, and the operation fails.

weight_h = Lambda(lambda x:k_switch(k_equal(x, 0), model_weight, user_weight), name='switch_weight_map')(user_weight_map_flg)

Reference information

https://stackoverflow.com/questions/52448652/attributeerror-nonetype-object-has-no-attribute-inbound-nodes-while-trying
https://stackoverflow.com/questions/44855603/typeerror-cant-pickle-thread-lock-objects-in-seq2seq
https://github.com/keras-team/keras/issues/8343
https://github.com/matterport/Mask_RCNN/issues/1126
https://stackoverflow.com/questions/53212672/read-only-mode-in-keras
https://stackoverflow.com/questions/47066635/checkpointing-keras-model-typeerror-cant-pickle-thread-lock-objects/55229794#55229794
https://blog.shikoan.com/lambda_arguments/
https://github.com/keras-team/keras/issues/6621
https://stackoverflow.com/questions/59635570/keras-backend-k-switch-for-loss-function-error
