python - 如何同时使用 DataFrameIterator 和 TensorSliceDataset 进行混合数据的训练?
问题描述
我正在尝试为黑色素瘤比赛的 kaggle 笔记本上的混合数据输入拟合一个多分支神经网络。我正在关注本教程以制作多分支神经网络。我现在无法展示网络架构,因为竞争正在进行中。我的模型在编译时没有显示错误。但是,在使用以下代码片段拟合模型时,代码单元会引发ValueError:
opt = Adam(lr=1e-3, decay=1e-3 / 200)
model.compile(loss=tensorflow.keras.losses.BinaryCrossentropy(from_logits=True), optimizer=opt)
# train the model
print("[INFO] training model...")
model.fit(
x=[train_dataset, train_generator],
epochs=2, batch_size=32)
这里,train_dataset 是从方法获取的 BatchDataset 对象,train_generator 是从from_tensor_slices().shuffle().batch()
方法获取的 DataFrameIterator flow_from_dataframe()
。
回溯如下所示:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-77-2e58e85574a4> in <module>
5 model.fit(
6 x=[train_dataset, train_generator],
----> 7 epochs=2, batch_size=32)
8 # make predictions on the testing data
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
64 def _method_wrapper(self, *args, **kwargs):
65 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
---> 66 return method(self, *args, **kwargs)
67
68 # Running inside `run_distribute_coordinator` already.
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
813 workers=workers,
814 use_multiprocessing=use_multiprocessing,
--> 815 model=self)
816
817 # Container that configures and calls `tf.keras.Callback`s.
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/data_adapter.py in __init__(self, x, y, sample_weight, batch_size, steps_per_epoch, initial_epoch, epochs, shuffle, class_weight, max_queue_size, workers, use_multiprocessing, model)
1097 self._insufficient_data = False
1098
-> 1099 adapter_cls = select_data_adapter(x, y)
1100 self._adapter = adapter_cls(
1101 x,
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/data_adapter.py in select_data_adapter(x, y)
961 "Failed to find data adapter that can handle "
962 "input: {}, {}".format(
--> 963 _type_name(x), _type_name(y)))
964 elif len(adapter_cls) > 1:
965 raise RuntimeError(
ValueError: Failed to find data adapter that can handle input: (<class 'list'> containing values of types {"<class 'tensorflow.python.data.ops.dataset_ops.BatchDataset'>", "<class 'keras_preprocessing.image.dataframe_iterator.DataFrameIterator'>"}), <class 'NoneType'>
有人可以帮我调试这个错误吗?
解决方案
没有答案。应该有针对 Keras 的功能请求。它们不支持数据集和数据框迭代器的混合输入类型。以下是 model.fit 的 TF 2.2 文档中的内容:“输入数据。它可能是:
A Numpy array (or array-like), or a list of arrays (in case the model has multiple inputs).
A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs).
A dict mapping input names to the corresponding array/tensors, if the model has named inputs.
A tf.data dataset. Should return a tuple of either (inputs, targets) or (inputs, targets, sample_weights).
A generator or keras.utils.Sequence returning (inputs, targets) or (inputs, targets, sample_weights). A more detailed description of unpacking behavior for iterator types (Dataset, generator, Sequence) is given below. "
推荐阅读
- html - 当我向我的徽标添加超链接时,它会忽略 CSS 大小
- javascript - SQL 数据库未更新 | JavaScript | 不和谐机器人
- r - 如果包含特定字符,则替换整个观察/字符串
- angular - 角度延迟加载在我的项目中不起作用
- node.js - 如何修复没有记录任何内容的快速应用程序
- indexing - 通过javascript函数更改淘汰赛foreach中特定项目的背景图像
- powerquery - 表中同一列乘以多列的递归表达式
- html - “溢出:隐藏”替代方法来切断 div 容器上方的元素
- git - 出现错误:将 WebApp 部署到 Azure 时出现 409 冲突
- python - 阿拉伯数字到罗马数字转换器:包含 4 的数字不会转换