python - 如何恢复使用 Dataset API 训练的 Tensorflow 模型?
问题描述
我正在使用带有可馈送迭代器的 Dataset API 训练我的模型,就像在此处的导入数据教程中一样。问题是,在恢复模型时。它还将从训练中恢复手柄占位符的形状。这意味着它期望得到一个例子和一个标签。
def loadTFRecord(filenames):
dataset = tf.data.TFRecordDataset([filenames])
dataset = dataset.map(extract_img_func)
dataset = dataset.batch(batchsize)
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, dataset.output_types, dataset.output_shapes)
training_iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
training_handle = self.sess.run(training_iterator.string_handle())
return next_element #next_element[0] is the example img, next_element[1] is the label
def model_fn(images, labels=None, train=False):
input_layer = images
...
predictions = last_layer
if train:
return predictions
# Calculate loss
loss = tf.losses.mean_squared_error(labels, predictions)
learning_rate = tf.train.exponential_decay(learning_rate=learningRate, staircase=True)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(
loss=loss,
global_step=global_step)
return train_op, predictions, loss
有了这个,我正在创建我的训练模型:
examples, labels = loadTFRecord("path/to/tfrecord")
model_fn(examples, labels=labels)
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=0.5)
... #training here
saver.save(sess, "path/to/")
现在的问题是,当我想恢复模型进行推理时。我想要做的是恢复模型并传入另一个可馈送迭代器,该迭代器从磁盘加载一些 .png 文件。我这样做类似于加载 TFRecord 文件。
def load_images(filenames):
dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.map(lambda x: tf.image.resize_images(self.normalize(tf.image.decode_png(tf.read_file(x), channels = 3)), [IM_WIDTH, IM_HEIGHT]))
dataset = dataset.batch(1)
iterator = tf.data.Iterator.from_string_handle(handle, dataset.output_types, dataset.output_shapes)
iterator = dataset.make_one_shot_iterator()
next_img = iterator.get_next()
training_handle = sess.run(iterator.string_handle())
return next_img
现在的问题是当将它传递给恢复的模型时,如下所示:
saver = tf.train.import_meta_graph(modelbasepath + ".meta")
saver.restore(sess, modelbasepath)
... # restore operations here
# finally run predictions, error occurs here!
predictions = sess.run([predictions], feed_dict={handle: training_handle})
我收到此错误:
Number of components does not match: expected 2 types but got 1.
[[Node: IteratorFromStringHandle_2 = IteratorFromStringHandle[output_shapes=[[?,80,80,3], [?,80,80,?]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_Placeholder_1_0_0)]]
这告诉我,它也期望得到一个标签,而我只是提供一个图像来预测。
我该如何克服呢?有没有办法改变占位符的形状,或者如何实现这一点,以便可以恢复使用数据集 API 和可馈送字典训练的模型?
解决方案
我遇到了同样的问题。但是,我想不出一个干净的解决方案。我最终为加载图像时返回的标签创建了一个虚拟张量。可能有更好的方法来执行此操作,但此解决方案现在应该允许您运行模型。
dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.map(load_images)
def load_images(x):
image = tf.image.decode_png(tf.read_file(x), channels = 3))
image = self.normalize(image)
image = tf.image.resize_images(image, [IM_WIDTH, IM_HEIGHT])
# Assuming label is one channel, can slice image to get correct dims
label = tf.zeros_like(image[:, :, 0:1])
return image, label
推荐阅读
- r - 将 R 降价参数传递给源 R 脚本
- javascript - 查找具有两个 CSS 类的元素
- javascript - 如何从对象属性中的数组中过滤掉项目
- reactjs - 利用本地状态并避免在数据已经可用时获取数据?
- ios - RPBroadcastSampleHandler 任何未调用的方法
- javascript - 无法获取未定义或空引用的属性“indexOf”
- r - 如何修复 R 中的此错误:[.data.frame(newdata, , object$method$center, drop = FALSE) : undefined columns selected
- iframe - HTA iframe 总是打开一个新窗口
- angular - 如何使用 Jest 在 Angular 上配置 Allure
- google-apps-script - 触发器在 Google Apps 脚本中的表单提交上不起作用