python - 如何使用与模型输入形状兼容的 tensorflow.data.experimental.CsvDataset 创建小批量?
问题描述
我将通过tensorflow.data.experimental.CsvDataset
在 TensorFlow 2 中使用来训练小批量。但张量的形状不适合我模型的输入形状。
请让我知道通过 TensorFlow 数据集进行小批量训练的最佳方法是什么。
我尝试如下:
# I have a dataset with 4 features and 1 label
feature = tf.data.experimental.CsvDataset(['C:/data/iris_0.csv'], record_defaults=[.0] * 4, header=True, select_cols=[0,1,2,3])
label = tf.data.experimental.CsvDataset(['C:/data/iris_0.csv'], record_defaults=[.0] * 1, header=True, select_cols=[4])
dataset = tf.data.Dataset.zip((feature, label))
# and I try to minibatch training:
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(loss='mse', optimizer='sgd')
model.fit(dataset.repeat(1).batch(3), epochs=1)
我收到一个错误:
ValueError:检查输入时出错:预期dense_6_input的形状为(4,)但得到的数组形状为(1,)
因为 :CsvDataset()
返回一个形状的张量(features, batch)
,但我需要它是形状的(batch, features)
。
参考代码:
for feature, label in dataset.repeat(1).batch(3).take(1):
print(feature)
# (<tf.Tensor: id=487, shape=(3,), dtype=float32, numpy=array([5.1, 4.9, 4.7], dtype=float32)>, <tf.Tensor: id=488, shape=(3,), dtype=float32, numpy=array([3.5, 3. , 3.2], dtype=float32)>, <tf.Tensor: id=489, shape=(3,), dtype=float32, numpy=array([1.4, 1.4, 1.3], dtype=float32)>, <tf.Tensor: id=490, shape=(3,), dtype=float32, numpy=array([0.2, 0.2, 0.2], dtype=float32)>)
解决方案
创建一个数据集,其中数据集的tf.data.experimental.CsvDataset
每个元素对应于 CSV 文件中的一行,并由多个张量组成,即每列都有一个单独的张量。因此,首先您需要使用map
数据集的方法将所有这些张量堆叠成一个张量,以便它与模型期望的输入形状兼容:
def map_func(features, label):
return tf.stack(features, axis=1), tf.stack(label, axis=1)
dataset = dataset.map(map_func).batch(BATCH_SIZE)
推荐阅读
- velo - 在 wix 中管理个人资料成员数据集,以在单击其他类别的电子邮件时添加
- php - 使用 docker 时应在何处或何时调用 php 框架?
- c# - 在 BlockingCollection 中搜索特定元素
- unity3d - 使用 Mirror unity 的多人游戏 UI
- react-native - 有天赋的聊天 renderAction 道具不会调用函数
- elasticsearch - ECK 吊舱的容差
- java - 构建期间发生 Eclipse 错误
- vue.js - VueRouter 高级模式匹配
- javascript - 有没有办法我可以选择 Material ui 中的 Typography 上呈现的值并将其分配给一个状态?
- android-jetpack-compose - 在 Jetpack Compose 中使 TextField 可滚动