[Python] How to make KerasRegressor saveable with pickle or joblib

Overview

Saving a KerasRegressor with pickle or joblib raises an error. This post shows how to make it saveable.

Solution

Monkey-patch KerasRegressor with the following __getstate__ and __setstate__ methods:

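import os
import tempfile

from keras import models
from keras.wrappers.scikit_learn import KerasRegressor  # assumes an older Keras 2.x install that still ships this wrapper


def KerasRegressor__getstate__(self):
    # Pickle only the constructor params and, if present, the trained model as HDF5 bytes.
    # build_fn is not saved (it is often a lambda or local function that pickle cannot handle),
    # so the restored object is meant for prediction, not for refitting.
    result = {'sk_params': self.sk_params}
    with tempfile.TemporaryDirectory() as tmp_dir:
        # model is absent when the estimator was never fitted, e.g. a clone created by a parent estimator
        if hasattr(self, 'model'):
            path = os.path.join(tmp_dir, 'output.h5')
            self.model.save(path, include_optimizer=False)
            with open(path, 'rb') as f:
                result['model'] = f.read()
    return result
KerasRegressor.__getstate__ = KerasRegressor__getstate__

def KerasRegressor__setstate__(self, serialized):
    self.sk_params = serialized['sk_params']
    model_data = serialized.get('model')
    if model_data:
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Write the HDF5 bytes back to a file so Keras can load them
            path = os.path.join(tmp_dir, 'input.h5')
            with open(path, 'wb') as f:
                f.write(model_data)
            self.model = models.load_model(path)
KerasRegressor.__setstate__ = KerasRegressor__setstate__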
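To check the round trip without installing Keras, the sketch below applies the same approach to a hypothetical stand-in. FakeModel and load_fake_model are invented placeholders for a Keras model and models.load_model (FakeModel holds a threading.Lock so that, like the wrapped model, it cannot be pickled directly), and Wrapper plays the role of KerasRegressor. As in the patch, __getstate__ saves the model to a temporary file and keeps its bytes, and __setstate__ writes the bytes back and reloads them.

```python
import json
import os
import pickle
import tempfile
import threading


class FakeModel:
    """Hypothetical stand-in for a Keras model: saves/loads its weights as a file."""

    def __init__(self, weights):
        self.weights = weights
        self._lock = threading.Lock()  # makes the object unpicklable on purpose

    def save(self, path):
        with open(path, 'w') as f:
            json.dump(self.weights, f)

    def predict(self, x):
        return [w * x for w in self.weights]


def load_fake_model(path):
    # Hypothetical stand-in for keras.models.load_model
    with open(path) as f:
        return FakeModel(json.load(f))


class Wrapper:
    """Plays the role of KerasRegressor: sk_params plus an optional fitted model."""

    def __init__(self, **sk_params):
        self.sk_params = sk_params

    def fit(self, weights):
        self.model = FakeModel(weights)
        return self

    def __getstate__(self):
        result = {'sk_params': self.sk_params}
        if hasattr(self, 'model'):  # unfitted instances have no model
            with tempfile.TemporaryDirectory() as tmp_dir:
                path = os.path.join(tmp_dir, 'output.json')
                self.model.save(path)
                with open(path, 'rb') as f:
                    result['model'] = f.read()
        return result

    def __setstate__(self, serialized):
        self.sk_params = serialized['sk_params']
        model_data = serialized.get('model')
        if model_data:
            with tempfile.TemporaryDirectory() as tmp_dir:
                path = os.path.join(tmp_dir, 'input.json')
                with open(path, 'wb') as f:
                    f.write(model_data)
                self.model = load_fake_model(path)


restored = pickle.loads(pickle.dumps(Wrapper(epochs=10).fit([1.0, 2.0])))
print(restored.sk_params, restored.model.predict(3))  # {'epochs': 10} [3.0, 6.0]

unfitted = pickle.loads(pickle.dumps(Wrapper(epochs=5)))
print(hasattr(unfitted, 'model'))  # False
```

Pickling a FakeModel on its own raises TypeError because of the lock, while the Wrapper pickles fine because its __getstate__ never hands the model object to pickle.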

Commentary

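__getstate__ and __setstate__ let a class customize how pickle serializes and deserializes its instances: pickle stores whatever __getstate__ returns and passes it back to __setstate__ when loading. Without the patch, pickle tries to serialize the wrapped Keras model as an ordinary object, which causes the error; the patch instead stores only sk_params and the model as HDF5 bytes. joblib uses pickle internally, so the same hooks also apply to joblib.dump and joblib.load. (For details, see the Python pickle module documentation.)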
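A minimal illustration of the protocol itself, independent of Keras: the hypothetical Counter class below holds a threading.Lock, which pickle cannot serialize. __getstate__ returns a copy of the instance dict without the lock, and __setstate__ restores that dict and creates a fresh lock.

```python
import pickle
import threading


class Counter:
    """Hypothetical example class with one unpicklable attribute."""

    def __init__(self):
        self.count = 0
        self.lock = threading.Lock()

    def increment(self):
        with self.lock:
            self.count += 1

    def __getstate__(self):
        # Called by pickle.dump/dumps: copy the instance dict and drop the lock
        state = self.__dict__.copy()
        del state['lock']
        return state

    def __setstate__(self, state):
        # Called by pickle.load/loads with whatever __getstate__ returned
        self.__dict__.update(state)
        self.lock = threading.Lock()


c = Counter()
c.increment()
c2 = pickle.loads(pickle.dumps(c))
c2.increment()
print(c2.count)  # 2
```

The KerasRegressor patch follows the same shape, except that the dropped attribute (the model) is converted to bytes rather than discarded.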
