Correct use of tensorflow flow_from_dataframe for multi-output/multi-label models

Problem description

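While trying to fit my model, I run into a shape incompatibility error. I'm not sure whether the problem lies in how I instantiate the DataIterator or in something in my model definition...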

Code to reproduce the problem, on Pastebin

Note: to use the code sample above with the sample data below, you must change directory=None to directory=data_sample and unpack data_sample.tgz in the same directory where you run this code.

Run this to train a model with a single prediction head (age):

python tf_train_copy.py --csv_file data_sample/dataframe.csv --count 100 --training_type age

Run this to train a model with two prediction heads (age and gender). It fails with the shape-related ValueError that this question is about:

python tf_train_copy.py --csv_file data_sample/dataframe.csv --count 100 --training_type age_gender

A small sample of 100 images and the associated dataframe

I have a dataframe that looks like this, with binary gender data {'m', 'f'} and age data that takes 3 values:

                                                 filename gender mask age_group
218069  /home/lrm/data/VGGFace2_cropped/n005204_0182_0...      f    f  (26, 39)
295315  /home/lrm/data/VGGFace2_cropped/n000162_0085_0...      m    f  (18, 25)
176301  /home/lrm/data/VGGFace2_cropped/n005378_0359_0...      f    f  (26, 39)
212662  /home/lrm/data/VGGFace2_cropped/n006412_0026_0...      m    f  (40, 55)
327910  /home/lrm/data/VGGFace2_cropped/n005114_0240_0...      m    f  (18, 25)
...                                                   ...    ...  ...       ...
283903  /home/lrm/data/VGGFace2_cropped/n002902_0158_0...      m    f  (26, 39)
156748  /home/lrm/data/VGGFace2_cropped/n003909_0464_0...      m    f  (40, 55)
89294   /home/lrm/data/VGGFace2_cropped/n000332_0195_0...      f    f  (18, 25)
156880  /home/lrm/data/VGGFace2_cropped/n005892_0122_0...      f    f  (18, 25)
304084  /home/lrm/data/VGGFace2_cropped/n002321_0114_0...      m    f  (18, 25)
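
Both label columns above hold strings, so they have to be turned into numbers at some point before a loss can use them. The sketch below shows one way to do that with NumPy, assuming the age groups are stored as strings such as "(26, 39)". The arrays are made-up stand-ins for the two columns, not the real data.

```python
import numpy as np

# Hypothetical stand-ins for the 'age_group' and 'gender' columns shown above.
age_group = np.array(["(26, 39)", "(18, 25)", "(26, 39)", "(40, 55)", "(18, 25)"])
gender = np.array(["f", "m", "f", "m", "m"])

# np.unique returns the sorted vocabulary and, with return_inverse=True,
# the position of each row's label inside that vocabulary.
age_classes, age_ids = np.unique(age_group, return_inverse=True)
gender_classes, gender_ids = np.unique(gender, return_inverse=True)

print(age_classes, age_ids)
print(gender_classes, gender_ids)
```

Integer ids like these pair with a sparse categorical cross-entropy loss. A plain categorical cross-entropy loss expects one-hot vectors instead.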

And a model defined like this:

dense_age = tf.keras.layers.Dense(units=64, activation=tf.keras.layers.ReLU(6.0), name='dense_age')(base_model.output)
dropout_age = tf.keras.layers.Dropout(rate=args.dropout)(dense_age)
pred_age = tf.keras.layers.Dense(units=3, activation='softmax', name='pred_age')(dropout_age)

dense_gender = tf.keras.layers.Dense(units=64, activation=tf.keras.layers.ReLU(6.0), name='dense_gender')(base_model.output)
dropout_gender = tf.keras.layers.Dropout(rate=args.dropout)(dense_gender)
# also tried units=2 for gender prediction
pred_gender = tf.keras.layers.Dense(units=1, activation='softmax', name='pred_gender')(dropout_gender)

model = tf.keras.models.Model(inputs=base_model.input, outputs=[pred_age, pred_gender])
model.compile(
        optimizer=optimizer,
        loss=[tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1), tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1)],
        metrics=[tf.keras.metrics.Accuracy(name='age_accuracy'), tf.keras.metrics.Accuracy(name='gender_accuracy')])
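
One detail of the gender head can be checked without the data pipeline: a softmax over a single unit always outputs 1.0, so a `units=1` softmax layer cannot separate two classes. The NumPy sketch below illustrates this. The `softmax` helper is written here for illustration and is not part of the question's code.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax, written out for illustration."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# Four samples, one unit each: every row normalises to exactly 1.0.
single_unit = softmax(np.array([[-3.0], [0.0], [2.5], [10.0]]))

# Four samples, two units each: rows now carry real class probabilities.
two_units = softmax(np.array([[-3.0, 1.0], [0.0, 0.0], [2.5, -1.0], [10.0, 9.0]]))

print(single_unit.ravel())
print(two_units.round(3))
```

Common alternatives are `units=2` with softmax and one-hot targets, or `units=1` with a sigmoid and binary cross-entropy. Two further points may be worth checking in the `compile` call. `from_logits=True` tells the loss to expect raw scores, while both heads already apply softmax. And `tf.keras.metrics.Accuracy` tests exact equality between targets and predictions, so `CategoricalAccuracy` is probably what was intended.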

This is how my DataIterator object is created:

train_generator = train_datagen.flow_from_dataframe(
        directory=None,
        dataframe=train_data,
        x_col='filename',
        y_col=['age_group', 'gender'],
        class_mode='multi_output',
        **dataflow_kwargs)
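
As far as I can tell, `class_mode='multi_output'` hands back the raw column values as a list of arrays, one per entry in `y_col`, without one-hot encoding them. It is worth confirming this by printing one batch. If that is the case, one possible workaround is to wrap the iterator and encode each batch. The sketch below assumes the label vocabularies visible in the dataframe and a 2-unit gender head. The helper names and the fake batch are hypothetical.

```python
import numpy as np

# Assumed label vocabularies, based on the values visible in the dataframe.
AGE_CLASSES = ["(18, 25)", "(26, 39)", "(40, 55)"]
GENDER_CLASSES = ["f", "m"]

def one_hot(labels, classes):
    """Map raw labels to a (batch, len(classes)) float32 one-hot array."""
    index = {label: i for i, label in enumerate(classes)}
    ids = np.array([index[str(label)] for label in labels])
    return np.eye(len(classes), dtype="float32")[ids]

def with_one_hot_targets(batches):
    """Wrap an iterator of (images, [age_labels, gender_labels]) batches."""
    for images, (age_labels, gender_labels) in batches:
        yield images, (one_hot(age_labels, AGE_CLASSES),
                       one_hot(gender_labels, GENDER_CLASSES))

# Stand-in for one batch from train_generator: 4 tiny fake images plus raw labels.
fake_batches = iter([(
    np.zeros((4, 8, 8, 3), dtype="float32"),
    [np.array(["(26, 39)", "(18, 25)", "(40, 55)", "(18, 25)"]),
     np.array(["f", "m", "m", "f"])],
)])

images, (y_age, y_gender) = next(with_one_hot_targets(fake_batches))
print(y_age.shape, y_gender.shape)
```

In real use, `with_one_hot_targets(train_generator)` would be passed to `model.fit` in place of the raw iterator, with `steps_per_epoch` set explicitly because the wrapped generator has no length.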

I get a ValueError saying that the shapes are incompatible:

Traceback (most recent call last):
  File "tf_train.py", line 385, in <module>
    validation_steps=validation_steps)
  File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 324, in new_func
    return func(*args, **kwargs)
  File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1829, in fit_generator
    initial_epoch=initial_epoch)
  File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1098, in fit
    tmp_logs = train_function(iterator)
  File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 823, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 697, in _initialize
    *args, **kwds))
  File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    /home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:806 train_function  *
        return step_function(self, iterator)
    /home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:796 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    /home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:1211 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:2585 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:2945 _call_for_each_replica
        return fn(*args, **kwargs)
    /home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:789 run_step  **
        outputs = model.train_step(data)
    /home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:749 train_step
        y, y_pred, sample_weight, regularization_losses=self.losses)
    /home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/engine/compile_utils.py:204 __call__
        loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    /home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/losses.py:149 __call__
        losses = ag_call(y_true, y_pred)
    /home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/losses.py:253 call  **
        return ag_fn(y_true, y_pred, **self._fn_kwargs)
    /home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:201 wrapper
        return target(*args, **kwargs)
    /home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/losses.py:1535 categorical_crossentropy
        return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
    /home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:201 wrapper
        return target(*args, **kwargs)
    /home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/backend.py:4687 categorical_crossentropy
        target.shape.assert_is_compatible_with(output.shape)
    /home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/framework/tensor_shape.py:1134 assert_is_compatible_with
        raise ValueError("Shapes %s and %s are incompatible" % (self, other))

    ValueError: Shapes (None, 1) and (None, 3) are incompatible
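
The last line says the loss received targets of shape (None, 1) while the prediction had shape (None, 3), which matches the 3-unit `pred_age` head. A plausible reading, not confirmed by the source, is that each sample arrives with one raw label instead of a 3-way one-hot vector. The helper below is a hypothetical, simplified restatement of the compatibility rule, not TensorFlow's implementation: shapes must have the same rank, and each pair of dimensions must be equal unless one is unknown (None).

```python
def shapes_compatible(a, b):
    """Return True if two static shapes could describe the same tensor.

    None stands for an unknown dimension and matches any size.
    """
    if len(a) != len(b):
        return False
    return all(x is None or y is None or x == y for x, y in zip(a, b))

target_shape = (None, 1)  # one raw label per sample
age_output = (None, 3)    # pred_age: softmax over 3 age groups

print(shapes_compatible(target_shape, age_output))  # the failing pair
print(shapes_compatible((None, 3), age_output))     # one-hot age targets
```

Under this rule the batch dimension is never the problem. The mismatch is in the last axis, which points at how the labels are encoded.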

Tags: tensorflow, multilabel-classification, multiclass-classification
