python - 验证数据可以成为 tensorflow.keras 2.0 中的生成器吗?
问题描述
在tensorflow.keras的官方文档中,
validation_data 可以是:Numpy 数组的元组 (x_val, y_val) 或 Numpy 数组数据集的张量元组 (x_val, y_val, val_sample_weights) 对于前两种情况,必须提供 batch_size。对于最后一种情况,可以提供validation_steps。
它没有提到生成器是否可以充当验证数据。所以我想知道validation_data 是否可以是数据生成器?像下面的代码:
net.fit_generator(train_it.generator(), epoch_iterations * batch_size, nb_epoch=nb_epoch, verbose=1,
validation_data=val_it.generator(), nb_val_samples=3,
callbacks=[checker, tb, stopper, saver])
更新:在keras的官方文档中,内容相同,但增加了一个句子:
- 数据集或数据集迭代器
考虑到
dataset 对于前两种情况,必须提供 batch_size。对于最后一种情况,可以提供validation_steps。
我认为应该有3种情况。Keras 的文件是正确的。所以我会在 tensorflow.keras 中发布一个问题来更新文档。
解决方案
是的,它可以,奇怪的是它不在文档中,但它的工作方式与x
论点完全一样,您也可以使用 akeras.Sequence
或 a generator
。在我的项目中,我经常使用keras.Sequence
它就像一个生成器
显示它有效的最小工作示例:
import numpy as np
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten
def generator(batch_size): # Create empty arrays to contain batch of features and labels
batch_features = np.zeros((batch_size, 1000))
batch_labels = np.zeros((batch_size,1))
while True:
for i in range(batch_size):
yield batch_features, batch_labels
model = Sequential()
model.add(Dense(125, input_shape=(1000,), activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
train_generator = generator(64)
validation_generator = generator(64)
model.fit(train_generator, validation_data=validation_generator, validation_steps=100, epochs=100, steps_per_epoch=100)
100/100 [==============================] - 1s 13ms/步 - 损失:0.6689 - 准确度:1.0000 - val_loss : 0.6448 - val_accuracy: 1.0000 Epoch 2/100 100/100 [==============================] - 0s 4ms/step - 损失:0.6223 - 准确度:1.0000 - val_loss:0.6000 - val_accuracy:1.0000 纪元 3/100 100/100 [========================= ====] - 0s 4ms/步 - 损失:0.5792 - 准确度:1.0000 - val_loss:0.5586 - val_accuracy: 1.0000 Epoch 4/100 100/100 [================ ==============] - 0s 4ms/步 - 损失:0.5393 - 准确度:1.0000 - val_loss:0.5203 - val_accuracy:1.0000
推荐阅读
- c# - 如何指定 Roslyn 方法参数
- c# - 如何检查ffmpeg何时完成任务?
- python - ImportError: No module named xxx, 不管是什么包
- javascript - 如何在 Nodejs 中集群 - 只是本地计算
- system-verilog - 获取所有 OVM 组件句柄的 API
- javascript - 使用 Ajax JSON 将 Map 发送到 servlet
- r - Debian 上的 CRAN 软件包错误。可以在 Windows 中查看照片,而不是在 Debian 中
- javascript - 如何检查 string.endsWith 并忽略空格
- python - python 3表现webdriver - OOP传递驱动程序
- python - Python将图像存储在列表图像数据中无法转换为浮点数