Ich erhalte eine Fehlermeldung, wenn ich versuche, KerasRegressor mit pickle oder joblib zu speichern. Wie man es speicherbar macht.
Affen Patch Keras Regressor unten
def KerasRegressor__getstate__(self):
result = { 'sk_params': self.sk_params }
with tempfile.TemporaryDirectory() as dir:
if hasattr(self, 'model'): #Es gibt Fälle, in denen es aufgrund eines Klons durch den übergeordneten Schätzer usw. nicht existiert.
self.model.save(dir + '/output.h5', include_optimizer=False)
with open(dir + '/output.h5', 'rb') as f:
result['model'] = f.read()
return result
KerasRegressor.__getstate__ = KerasRegressor__getstate__
def KerasRegressor__setstate__(self, serialized):
self.sk_params = serialized['sk_params']
with tempfile.TemporaryDirectory() as dir:
model_data = serialized.get('model')
if model_data:
with open(dir + '/input.h5', 'wb') as f:
f.write(model_data)
self.model = models.load_model(dir + '/input.h5')
KerasRegressor.__setstate__ = KerasRegressor__setstate__
__getstate__, __setstate__
Kann verwendet werden, um die Serialisierung und Deserialisierung von Gurken für jede Klasse anzupassen.(Für Details)
Recommended Posts