tensorflow - Correct use of TensorFlow flow_from_dataframe with a multi-output / multi-label model
Question
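I get an incompatible-shapes error when trying to fit my model. I am not sure whether the problem is in how the DataIterator is instantiated or in something in my model definition...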
Note: to use the code sample above with the sample data below, you must change directory=None to directory=data_sample and extract data_sample.tgz in the same directory where this code is run.
Run this to train the model with a single prediction head (age):
python tf_train_copy.py --csv_file data_sample/dataframe.csv --count 100 --training_type age
Run this to train the model with two prediction heads (age and gender), which fails with the shape-related ValueError that this question is about:
python tf_train_copy.py --csv_file data_sample/dataframe.csv --count 100 --training_type age_gender
I have a dataframe that looks like this, with binary gender data {'m', 'f'} and age data that takes 3 values:
filename gender mask age_group
218069 /home/lrm/data/VGGFace2_cropped/n005204_0182_0... f f (26, 39)
295315 /home/lrm/data/VGGFace2_cropped/n000162_0085_0... m f (18, 25)
176301 /home/lrm/data/VGGFace2_cropped/n005378_0359_0... f f (26, 39)
212662 /home/lrm/data/VGGFace2_cropped/n006412_0026_0... m f (40, 55)
327910 /home/lrm/data/VGGFace2_cropped/n005114_0240_0... m f (18, 25)
... ... ... ... ...
283903 /home/lrm/data/VGGFace2_cropped/n002902_0158_0... m f (26, 39)
156748 /home/lrm/data/VGGFace2_cropped/n003909_0464_0... m f (40, 55)
89294 /home/lrm/data/VGGFace2_cropped/n000332_0195_0... f f (18, 25)
156880 /home/lrm/data/VGGFace2_cropped/n005892_0122_0... f f (18, 25)
304084 /home/lrm/data/VGGFace2_cropped/n002321_0114_0... m f (18, 25)
And a model defined like this:
dense_age = tf.keras.layers.Dense(units=64, activation=tf.keras.layers.ReLU(6.0), name='dense_age')(base_model.output)
dropout_age = tf.keras.layers.Dropout(rate=args.dropout)(dense_age)
pred_age = tf.keras.layers.Dense(units=3, activation='softmax', name='pred_age')(dropout_age)
dense_gender = tf.keras.layers.Dense(units=64, activation=tf.keras.layers.ReLU(6.0), name='dense_gender')(base_model.output)
dropout_gender = tf.keras.layers.Dropout(rate=args.dropout)(dense_gender)
# also tried units=2 for gender prediction
pred_gender = tf.keras.layers.Dense(units=1, activation='softmax', name='pred_gender')(dropout_gender)
model = tf.keras.models.Model(inputs=base_model.input, outputs=[pred_age, pred_gender])
model.compile(
    optimizer=optimizer,
    loss=[tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
          tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1)],
    metrics=[tf.keras.metrics.Accuracy(name='age_accuracy'),
             tf.keras.metrics.Accuracy(name='gender_accuracy')])
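One detail of the model definition can be checked numerically without TensorFlow: pred_gender applies softmax to a single unit. Softmax normalizes across the units of a layer, so with one unit the output should be 1.0 for every input. The sketch below uses a hand-written NumPy softmax (the softmax helper and the sample logits are illustrative assumptions, not part of the training script) to show this.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis (illustrative helper)."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


# Three made-up samples passed through a 1-unit head and a 2-unit head.
one_unit = softmax(np.array([[-3.0], [0.5], [7.0]]))
two_units = softmax(np.array([[-3.0, 1.0], [0.5, 0.5], [7.0, 2.0]]))

print(one_unit.ravel())        # [1. 1. 1.] regardless of the input values
print(two_units.sum(axis=-1))  # each row sums to 1 across the two classes
```

If this reading applies here, a gender head would probably need either units=2 with softmax and one-hot labels, or units=1 with a sigmoid activation and a binary loss. Separately, both losses are created with from_logits=True while both heads already apply softmax; from_logits=True tells the loss to expect raw, unnormalized scores, so one of the two settings likely needs to change.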
This is how my DataIterator object is created:
train_generator = train_datagen.flow_from_dataframe(
    directory=None,
    dataframe=train_data,
    x_col='filename',
    y_col=['age_group', 'gender'],
    class_mode='multi_output',
    **dataflow_kwargs)
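As far as the Keras documentation describes it, class_mode='multi_output' yields a list with the values of the y_col columns and does not one-hot encode them, whereas CategoricalCrossentropy expects one-hot targets with the same last dimension as the corresponding output. A minimal sketch of one way to bridge that gap is a wrapper generator that encodes each label array. Everything named here (AGE_CLASSES, GENDER_CLASSES, one_hot, encode_multi_output, and the fake batch standing in for the DataIterator) is hypothetical; it assumes the labels arrive as strings and that the gender head has 2 units.

```python
import numpy as np

# Assumed class lists: the age groups are the three values visible in the
# sample rows, and the gender values come from {'m', 'f'}.
AGE_CLASSES = ["(18, 25)", "(26, 39)", "(40, 55)"]
GENDER_CLASSES = ["f", "m"]


def one_hot(values, classes):
    """Return a float32 array of shape (len(values), len(classes))."""
    index = {label: i for i, label in enumerate(classes)}
    ids = np.array([index[str(v)] for v in values])
    return np.eye(len(classes), dtype=np.float32)[ids]


def encode_multi_output(batches):
    """Wrap an iterator that yields (images, [age_labels, gender_labels])."""
    for images, (ages, genders) in batches:
        yield images, [one_hot(ages, AGE_CLASSES),
                       one_hot(genders, GENDER_CLASSES)]


# Stand-in for one batch from the DataIterator: 2 blank images plus raw labels.
fake_batches = [(np.zeros((2, 8, 8, 3), dtype=np.float32),
                 [np.array(["(26, 39)", "(18, 25)"]), np.array(["f", "m"])])]

images, (y_age, y_gender) = next(encode_multi_output(iter(fake_batches)))
print(y_age.shape, y_gender.shape)  # (2, 3) (2, 2)
```

In the real script the wrapped generator, not train_generator itself, would be passed to fit. Because a plain Python generator has no length, steps_per_epoch would then need to be given explicitly.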
I get a ValueError saying that the shapes are incompatible:
Traceback (most recent call last):
File "tf_train.py", line 385, in <module>
validation_steps=validation_steps)
File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 324, in new_func
return func(*args, **kwargs)
File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1829, in fit_generator
initial_epoch=initial_epoch)
File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
return method(self, *args, **kwargs)
File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1098, in fit
tmp_logs = train_function(iterator)
File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
result = self._call(*args, **kwds)
File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 823, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 697, in _initialize
*args, **kwds))
File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:806 train_function *
return step_function(self, iterator)
/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:796 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:1211 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:2585 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:2945 _call_for_each_replica
return fn(*args, **kwargs)
/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:789 run_step **
outputs = model.train_step(data)
/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:749 train_step
y, y_pred, sample_weight, regularization_losses=self.losses)
/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/engine/compile_utils.py:204 __call__
loss_value = loss_obj(y_t, y_p, sample_weight=sw)
/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/losses.py:149 __call__
losses = ag_call(y_true, y_pred)
/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/losses.py:253 call **
return ag_fn(y_true, y_pred, **self._fn_kwargs)
/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:201 wrapper
return target(*args, **kwargs)
/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/losses.py:1535 categorical_crossentropy
return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:201 wrapper
return target(*args, **kwargs)
/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/keras/backend.py:4687 categorical_crossentropy
target.shape.assert_is_compatible_with(output.shape)
/home/lrm/.pyenv/versions/tf_training_231/lib/python3.7/site-packages/tensorflow/python/framework/tensor_shape.py:1134 assert_is_compatible_with
raise ValueError("Shapes %s and %s are incompatible" % (self, other))
ValueError: Shapes (None, 1) and (None, 3) are incompatible
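The last frames of the traceback show categorical_crossentropy calling target.shape.assert_is_compatible_with(output.shape), so (None, 1) is the label shape and (None, 3) is the prediction shape. The 3 presumably corresponds to pred_age with units=3, which would mean the age labels reach the loss as one value per sample rather than as a 3-wide one-hot vector. The sketch below mimics the compatibility rule in plain Python; is_compatible is an illustrative helper, not TensorFlow's implementation, and it ignores the unknown-rank case.

```python
def is_compatible(shape_a, shape_b):
    """Simplified static-shape check: ranks must match, and each pair of
    dimensions must be equal unless one of them is unknown (None)."""
    if len(shape_a) != len(shape_b):
        return False
    return all(a is None or b is None or a == b
               for a, b in zip(shape_a, shape_b))


target_shape = (None, 1)  # label shape reported in the error
output_shape = (None, 3)  # prediction shape reported in the error

print(is_compatible(target_shape, output_shape))  # False: 1 != 3
print(is_compatible((None, 3), output_shape))     # True: one-hot labels would match
```

The batch dimension is None on both sides and therefore never conflicts; only the class dimension has to agree.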