Can't pickle a TensorFlow object in Python - TypeError: can't pickle _thread._local objects

Problem description

I want to pickle the History object returned by Keras fit (running on TensorFlow), but I get an error.

import gzip
import numpy as np
import os
import pickle
import tensorflow as tf
from tensorflow import keras


with gzip.open('mnist.pkl.gz', 'rb') as f:
    train_set, test_set = pickle.load(f, encoding='latin1')

X_train = np.asarray(train_set[0])
y_train = np.asarray(train_set[1])

X_test = np.asarray(test_set[0])
y_test = np.asarray(test_set[1])

X_valid, X_train = X_train[:5000]/255.0, X_train[5000:]/255.0
y_valid, y_train = y_train[:5000], y_train[5000:]

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot']

model = keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=[28, 28]))
model.add(keras.layers.Dense(300, activation='relu'))
model.add(keras.layers.Dense(100, activation='relu'))
model.add(keras.layers.Dense(10, activation='softmax'))
model.summary()

model.compile(loss='sparse_categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])

history = model.fit(X_train, y_train, epochs=1,
                    validation_data=(X_valid, y_valid))

if not os.path.isdir('models'):
    os.mkdir('models')

model.save('models/basic.h5')
with open('models/basic_history.pickle', 'wb') as f:
    pickle.dump(history, f)

It gives me the following error:

Traceback (most recent call last):
  File "main.py", line 69, in <module>
    pickle.dump(history, f)
TypeError: can't pickle _thread._local objects

PS: To run the code, download the fashion_mnist data: https://s3.amazonaws.com/img-datasets/mnist.pkl.g

Tags: python, tensorflow, pickle
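The TypeError comes from pickle itself: a threading.local instance (CPython's _thread._local) cannot be pickled, and neither can any object that references one. The sketch below reproduces this without TensorFlow. FakeHistory and try_pickle are hypothetical names made up for illustration; FakeHistory is not the real Keras class, it only mimics the assumed shape of the problem (a plain metrics dict next to a reference to something holding thread-local state, roughly like History keeping a reference to the model). The exact error text differs by Python version ("can't pickle _thread._local objects" in older releases, "cannot pickle '_thread._local' object" in newer ones).

```python
import pickle
import threading


class FakeHistory:
    """Hypothetical stand-in for a Keras History callback (not the real class).

    It keeps a plain dict of metrics plus a reference to an object holding
    thread-local state, similar in spirit to History referencing the model.
    """

    def __init__(self):
        self.history = {"loss": [0.7, 0.5]}  # illustrative numbers only
        self.model = threading.local()       # stands in for unpicklable model state


def try_pickle(obj):
    """Return True if obj pickles, False if pickle raises TypeError."""
    try:
        pickle.dumps(obj)
        return True
    except TypeError as exc:
        print(f"{type(obj).__name__}: {exc}")
        return False


h = FakeHistory()
print(try_pickle(h))          # False: the thread-local attribute blocks pickling
print(try_pickle(h.history))  # True: a plain dict of lists pickles fine
```

The same reasoning explains the fix: pickle only the plain data you need rather than the object that drags the model along with it.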

Solution


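As Karl suggested, the History object itself can't be pickled (it keeps a reference to the model, which is where the thread-local state lives). Its history dictionary (history.history), which maps each metric name to a list of per-epoch values, can be pickled: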

with open('models/basic_history.pickle', 'wb') as f:
    pickle.dump(history.history, f)
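To read the metrics back later, load the same file with pickle.load. A minimal round-trip sketch follows; save_history and load_history are hypothetical helper names, the metric values are made up for illustration, and the file goes to a temporary directory instead of models/ so the snippet runs on its own. The key names ("accuracy" vs. "acc", the "val_" prefix) are an assumption here: they depend on the Keras/TensorFlow version and on the metrics passed to compile.

```python
import os
import pickle
import tempfile


def save_history(history_dict, path):
    """Pickle a plain metrics dict, e.g. the history.history attribute."""
    with open(path, "wb") as f:
        pickle.dump(history_dict, f)


def load_history(path):
    """Load a metrics dict previously written by save_history."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Illustrative values only; in the question this would be history.history
# from model.fit(...). Actual keys depend on the compiled metrics and version.
example = {
    "loss": [0.72],
    "accuracy": [0.77],
    "val_loss": [0.51],
    "val_accuracy": [0.83],
}

path = os.path.join(tempfile.mkdtemp(), "basic_history.pickle")
save_history(example, path)
restored = load_history(path)
print(restored["val_accuracy"])  # [0.83]
```

Because only a dict of lists is stored, the file can be loaded without rebuilding the model; the model itself is saved separately with model.save, as in the question.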
