[PYTHON] CPickle.UnpicklingError in Chainer

J'ai fait référence à certains sites pour utiliser Chainer.

Cependant, lorsque j'ai exécuté train_imagenet.py pour apprendre ma propre image, l'erreur suivante s'est produite.

Erreur


cPickle.UnpicklingError: invalid load key, 

La partie correspondante est un traitement non-Pickle par la fonction appelée pickle.load du code de ↓

train_imagenet.py


# Prepare dataset
train_list = load_image_list(args.train, args.root)
val_list = load_image_list(args.val, args.root)
mean_image = pickle.load(open(args.mean, 'rb'))

La valeur de l'argument args.mean est un fichier appelé mean.npy, donc si vous recherchez la source de ce fichier ...

compute_mean.py


#!/usr/bin/env python
import argparse
import os
import sys

import numpy
from PIL import Image
import six.moves.cPickle as pickle


parser = argparse.ArgumentParser(description='Compute images mean array')
parser.add_argument('dataset', help='Path to training image-label list file')
parser.add_argument('--root', '-r', default='.',
                    help='Root directory path of image files')
parser.add_argument('--output', '-o', default='mean.npy',
                    help='path to output mean array')
args = parser.parse_args()

sum_image = None
count = 0
for line in open(args.dataset):
    filepath = os.path.join(args.root, line.strip().split()[0])
    image = numpy.asarray(Image.open(filepath)).transpose(2, 0, 1)
    if sum_image is None:
        sum_image = numpy.ndarray(image.shape, dtype=numpy.float32)
        sum_image[:] = image
    else:
        sum_image += image
    count += 1
    sys.stderr.write('\r{}'.format(count))
    sys.stderr.flush()

sys.stderr.write('\n')

mean = sum_image / count
pickle.dump(mean, open(args.output, 'wb'), -1)

Il semble que l'objet créé par la fonction numpy.ndarray est sorti dans un fichier appelé mean.npy par la fonction pickle.dump. En d'autres termes, l'entité de mean.npy est comme un flux d'octets d'un tableau NumPy.

Donc, au lieu de lire mean.npy comme non-Pickle dans train_imagenet.py, je l'ai modifié pour le lire comme un tableau NumPy.

train_imagenet.py


# Prepare dataset
train_list = load_image_list(args.train, args.root)
val_list = load_image_list(args.val, args.root)
# mean_image = pickle.load(open(args.mean, 'rb'))← cPickle lorsqu'il est lu comme non-Pickle.UnpicklingError
mean_image = np.load(args.mean) #Lire comme un tableau NumPy

Puis j'ai réussi à le lire.

Recommended Posts

CPickle.UnpicklingError in Chainer
Historique de DQN + Deep Q-Network écrit dans Chainer
Comment créer des données à mettre dans CNN (Chainer)