首页 > 解决方案 > 使用 keras 回调对当前批次进行预测

问题描述

我正在尝试使用 Keras 回调(callback)在每个批次结束时进行预测,如下所示:

from keras.layers import Dense
from keras.models import Sequential
from keras.callbacks import Callback
from keras import backend as K
import tensorflow as tf
import numpy as np


class CollectOutputAndTarget(Callback):
    """Collect per-batch inputs, targets and outputs while ``fit`` runs.

    The three ``tf.Variable`` attributes are meant to be filled by
    ``tf.assign`` ops that the caller attaches to the training function
    (for example via ``model._function_kwargs = {'fetches': [...]}``).
    After every batch this callback reads them back. It also runs a fresh
    forward pass on the batch input:

    * ``self.outputs`` holds the outputs computed during the training step,
      i.e. before the weight update.
    * ``self.preds`` holds predictions computed after the weight update.
    """

    def __init__(self):
        super(CollectOutputAndTarget, self).__init__()
        self.targets = []  # collect y_true batches
        self.inputs = []  # collect input (x) batches
        self.outputs = []  # collect y_pred batches from the training step
        self.preds = []  # collect predictions made after the weight update

        # the shape of these 3 variables will change according to batch shape
        # to handle the "last batch", specify `validate_shape=False`
        self.var_y_true = tf.Variable(0., validate_shape=False)
        self.var_input = tf.Variable(0., validate_shape=False)
        self.var_y_pred = tf.Variable(0., validate_shape=False)

        # Backend function mapping the first model input to the first model
        # output. It is built lazily because `self.model` is only attached
        # once training starts.
        self._predict_fn = None

    def _predict(self, batch_inp):
        """Run a forward pass on ``batch_inp`` with the current weights.

        ``self.model.predict`` is deliberately not used here. The predict
        function Keras builds also receives ``model._function_kwargs``, so
        the attached ``tf.assign`` fetches run too. One of those fetches reads
        the target placeholder, which is not fed during prediction, so the
        call fails with "You must feed a value for placeholder tensor
        '..._target'". A plain backend function over the model's tensors
        carries no such fetches.

        Args:
            batch_inp: array fed to ``self.model.inputs[0]``.

        Returns:
            The value of ``self.model.outputs[0]`` for ``batch_inp``.
        """
        if self._predict_fn is None:
            # NOTE(review): the learning phase is not fed, so this relies on
            # the Keras backend giving it a default (inference) value.
            self._predict_fn = K.function([self.model.inputs[0]],
                                          [self.model.outputs[0]])
        return self._predict_fn([batch_inp])[0]

    def on_batch_end(self, batch, logs=None):
        """Store the assigned variables and a post-update prediction.

        Args:
            batch: index of the batch within the epoch (unused).
            logs: per-batch metrics dict supplied by Keras (unused).
        """
        # evaluate the variables and save them into lists
        self.targets.append(K.eval(self.var_y_true))
        batch_inp = K.eval(self.var_input)
        self.inputs.append(batch_inp)
        self.outputs.append(K.eval(self.var_y_pred))
        # The training step for this batch has already run, so this forward
        # pass sees the weights updated on the current batch.
        self.preds.append(self._predict(batch_inp))


# Build a simple two-layer model (2 inputs -> 5 -> 2 outputs) in a fresh
# backend session.
K.clear_session()
# have to compile first for model.targets and model.outputs to be prepared
model = Sequential([Dense(5, input_shape=(2,)), Dense(2)])
model.compile(loss='mse', optimizer='adam')

# Create the callback (which owns the three variables) and one `tf.assign` op
# per variable. Each op copies a symbolic model tensor (target, first input,
# first output) into the matching callback variable. `validate_shape=False`
# lets the variables take on each batch's shape.
cbk = CollectOutputAndTarget()
fetches = [tf.assign(cbk.var_y_true, model.targets[0], validate_shape=False),
           tf.assign(cbk.var_input, model.inputs[0], validate_shape=False),
           tf.assign(cbk.var_y_pred, model.outputs[0], validate_shape=False)]
# `_function_kwargs` is a private Keras attribute: the extra `fetches` are run
# by the backend functions Keras builds from it, so every training step also
# executes the assigns above.
# NOTE(review): the `dense_2_target` feed error suggests the predict function
# picks these fetches up as well; confirm against the Keras version in use.
model._function_kwargs = {'fetches': fetches}


# Fit on 5 samples with batch_size=3, so the epoch has a batch of 3 followed
# by a batch of 2 (the "last batch" case). shuffle=False keeps the batch order
# deterministic.
X = np.arange(10).reshape((5, 2))
Y = X*2  # targets are twice the inputs

model.fit(X, Y, epochs=1, batch_size=3, callbacks=[cbk], shuffle=False)

我收到以下错误:


InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-114-adfad08009ad> in <module>
      3 Y = X*2
      4 
----> 5 model.fit(X, Y, epochs=1, batch_size=3, callbacks=[cbk], shuffle=False)

/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
   1037                                         initial_epoch=initial_epoch,
   1038                                         steps_per_epoch=steps_per_epoch,
-> 1039                                         validation_steps=validation_steps)
   1040 
   1041     def evaluate(self, x=None, y=None,

/usr/local/lib/python3.6/dist-packages/keras/engine/training_arrays.py in fit_loop(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)
    202                     batch_logs[l] = o
    203 
--> 204                 callbacks.on_batch_end(batch_index, batch_logs)
    205                 if callback_model.stop_training:
    206                     break

/usr/local/lib/python3.6/dist-packages/keras/callbacks.py in on_batch_end(self, batch, logs)
    113         t_before_callbacks = time.time()
    114         for callback in self.callbacks:
--> 115             callback.on_batch_end(batch, logs)
    116         self._delta_ts_batch_end.append(time.time() - t_before_callbacks)
    117         delta_t_median = np.median(self._delta_ts_batch_end)

<ipython-input-111-65feb418f9ec> in on_batch_end(self, batch, logs)
     19         self.inputs.append(batch_inp)
     20         self.outputs.append(K.eval(self.var_y_pred))
---> 21         current_pred = self.model.predict(batch_inp)
     22         self.preds.append(current_pred)

/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in predict_on_batch(self, x)
   1272             ins = x
   1273         self._make_predict_function()
-> 1274         outputs = self.predict_function(ins)
   1275         return unpack_singleton(outputs)
   1276 

/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2713                 return self._legacy_call(inputs)
   2714 
-> 2715             return self._call(inputs)
   2716         else:
   2717             if py_any(is_tensor(x) for x in inputs):

/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py in _call(self, inputs)
   2673             fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
   2674         else:
-> 2675             fetched = self._callable_fn(*array_vals)
   2676         return fetched[:len(self.outputs)]
   2677 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
   1456         ret = tf_session.TF_SessionRunCallable(self._session._session,
   1457                                                self._handle, args,
-> 1458                                                run_metadata_ptr)
   1459         if run_metadata:
   1460           proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

InvalidArgumentError: You must feed a value for placeholder tensor 'dense_2_target' with dtype float and shape [?,?]
     [[{{node dense_2_target}}]]

在每个批次结束时,我可以从变量 `self.var_y_pred`(它被赋值为 `model.outputs[0]`)中取得模型输出。

但据我了解,这个输出是在当前步骤反向传播(即权重更新)之前计算的。我的目标是:用已经根据当前批次完成权重更新的模型,对当前批次进行预测。

我怎样才能做到这一点?

标签: python, tensorflow, keras

解决方案


答案是“你不能”。

对象 `model.inputs` 和 `model.outputs` 是“张量”列表,而不是数据。张量只是计算图中的符号表示,本身并不包含数值。

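获得某个批次预测结果的唯一方法,是调用 `model.predict_on_batch(input_data_as_numpy)` 或类似方法。在您的情况下,这意味着模型要对同一批数据做两次前向计算,会带来相当严重的性能损失。(另外,您看到的报错的直接原因是:通过 `model._function_kwargs` 附加的 `fetches` 也会被加入 Keras 构建的预测函数。其中的 `tf.assign` 依赖目标占位符 `dense_2_target`,而预测时并没有为它提供值。)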

如果要在训练期间直接使用当前批次的预测结果,需要切换到 Eager 模式并编写自定义训练循环(原答案的示例代码未包含在本页中):


推荐阅读