python - model.fit() 不接受 tf.data.Dataset 的输入形状
问题描述
我想通过应用tf.data.Dataset
.
检查 TF 2.0 的文档后,我发现该.fit()
函数(https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit)接受:
x - 一个 tf.data 数据集。应返回 (inputs, targets) 或 (inputs, targets, sample_weights) 的元组。
因此,我编写了以下小型概念验证代码:
from sklearn.datasets import make_blobs
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.metrics import Accuracy, AUC
X, Y = make_blobs(n_samples=500, n_features=2, cluster_std=3.0, random_state=1)
def define_model():
model = Sequential()
model.add(Dense(units=1, activation="sigmoid", input_shape=(2,)))
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=[AUC(), Accuracy()])
return model
model = define_model()
X_ds = tf.data.Dataset.from_tensor_slices(X)
Y_ds = tf.data.Dataset.from_tensor_slices(Y)
dataset = tf.data.Dataset.zip((X_ds, Y_ds))
for elem in dataset.take(1):
print(type(elem))
print(elem)
model.fit(x=dataset) #<-- does not work
#model.fit(x=X, y=Y) <-- does work without any problems....
正如第二条评论中提到的,不应用 a 的代码可以tf.data.Dataset
正常工作。
但是,在应用 Dataset 对象时,我收到以下错误消息:
<class 'tuple'>
(<tf.Tensor: shape=(2,), dtype=float64, numpy=array([-10.42729974, -0.85439721])>, <tf.Tensor: shape=(), dtype=int64, numpy=1>)
... other output here...
ValueError: Error when checking input: expected dense_19_input to have
shape (2,) but got array with shape (1,)
根据我对文档的理解,我构建的数据集应该正是 fit 方法所期望的元组对象。
我不明白这个错误信息。
我在这里做错了什么?
解决方案
当您将数据集传递给 时fit
,预计它将直接生成批次,而不是单个示例。您只需要在训练之前对数据集进行批处理。
dataset = dataset.batch(batch_size)
model.fit(x=dataset)
推荐阅读
- c++ - QPushButton 更改布局中小部件的大小
- mongodb - Mongoose 过滤子文档
- python - 我收到此错误 AttributeError: 'function' object has no attribute 'hlauncher' 同时尝试从另一个文件获取属性
- r - 将向量转换为 R 中的密度向量
- python - 在 matplotlib 中选择要绘制的子图
- css - 溢出:隐藏在 Div 中无法正常工作
- python - 对除每 n 次以外的所有项目进行切片
- postgresql - 如何在不锁定整个表的情况下强制 Postgres 返回合理的行数?
- wordpress - 带有 Wordpress 的 Google Cloud VM 实例反复崩溃
- java - JDK 11.0.2 编译失败,javac NPE 在匿名参数化类类型推断上