python-3.x - Tensorflow Keras issue with a custom generator in Model.fit
Problem description
I created a custom generator by subclassing tensorflow.keras.utils.Sequence
and tried to fit a simple model with it.
import math
import tensorflow as tf
import keras
import numpy as np
from tensorflow.keras.utils import Sequence
from keras import Sequential
from keras.layers import InputLayer, Flatten, Dense
class MyDataGenerator(Sequence):
    def __init__(self, df, x_col='filename', y_col='class',
                 batch_size=32, path='./', num_classes=None, shuffle=True,
                 dim=(634,513,1), nfft=1024, hstep=256, sr=16000):
        self.batch_size = batch_size
        self.df = df
        self.indices = self.df.index.tolist()
        self.num_classes = num_classes
        self.path = path
        self.shuffle = shuffle
        self.x_col = x_col
        self.y_col = y_col
        # Builds (and optionally shuffles) the index order for the first epoch.
        self.on_epoch_end()
        self.dim = dim
        self.nfft = nfft
        self.hstep = hstep
        self.sr = sr

    def __getitem__(self, index):
        # Positions of the samples that belong to batch number `index`.
        index = self.index[index * self.batch_size:(index + 1) * self.batch_size]
        batch = [self.indices[k] for k in index]
        # Placeholder batch: the actual loading is left out here, so X and y are all zeros.
        X = np.zeros((self.batch_size, *self.dim), dtype=np.float32)
        y = np.zeros((self.batch_size,), dtype=np.uint32)
        return X, y

    def __len__(self):
        # Number of batches per epoch; the last one may be only partly filled.
        return math.ceil(len(self.indices) / self.batch_size)

    def on_epoch_end(self):
        self.index = np.arange(len(self.indices))
        if self.shuffle:
            np.random.shuffle(self.index)
train_datagen = MyDataGenerator(df_training)
valid_datagen = MyDataGenerator(df_validation, shuffle=False)
x,y = train_datagen[27]
print(x.shape, y.shape, x.dtype, y.dtype, type(x), type(y))
print(y)
(32, 634, 513, 1) (32,) float32 uint32 <class 'numpy.ndarray'> <class 'numpy.ndarray'>
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
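As a side note on how the generator divides the data into batches: `__len__` rounds up, so the last batch can cover fewer rows than `batch_size`, while `__getitem__` above always allocates `batch_size` rows. Below is a minimal sketch of that index arithmetic using only the standard library. `batch_slices` is a hypothetical helper written for illustration, not part of Keras.

```python
import math


def batch_slices(n_samples, batch_size):
    """Return one (start, stop) pair per batch, mirroring the slicing in __getitem__."""
    n_batches = math.ceil(n_samples / batch_size)  # same formula as __len__
    return [
        (i * batch_size, min((i + 1) * batch_size, n_samples))
        for i in range(n_batches)
    ]


# 70 samples with batch_size=32 give two full batches and one partial batch.
slices = batch_slices(70, 32)
print(slices)  # [(0, 32), (32, 64), (64, 70)]
```

Sizing `X` and `y` by `len(batch)` rather than `self.batch_size` would be one way to keep the final batch from carrying extra zero rows.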
model = Sequential(name='Test_model')
model.add(InputLayer((634, 513, 1), name='Input'))
model.add(Flatten(name='Flatten'))
model.add(Dense(1, activation='sigmoid', name='Output'))
model.summary()
model.compile(optimizer='adam', loss='binary_crossentropy',
              metrics=['accuracy'])
model.fit(
    train_datagen,
    epochs=1,
    verbose=0,
    validation_data=valid_datagen
)
I get a ValueError:
Model: "Test_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
Flatten (Flatten)            (None, 325242)            0
_________________________________________________________________
Output (Dense)               (None, 1)                 325243
=================================================================
Total params: 325,243
Trainable params: 325,243
Non-trainable params: 0
_________________________________________________________________
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-203-533c619e3ae0> in <module>()
     11     epochs=1,
     12     verbose=0,
---> 13     validation_data=valid_datagen
     14 )

3 frames
/usr/local/lib/python3.7/dist-packages/keras/engine/data_adapter.py in select_data_adapter(x, y)
    976         "Failed to find data adapter that can handle "
    977         "input: {}, {}".format(
--> 978             _type_name(x), _type_name(y)))
    979   elif len(adapter_cls) > 1:
    980     raise RuntimeError(

ValueError: Failed to find data adapter that can handle input: <class '__main__.MyDataGenerator'>, <class 'NoneType'>
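I can't figure out why the model refuses to train from the generator. The data types are correct, and so are the inputs and outputs.
If I train without the generator, the model works fine, but the dataset is too large for me to drop the generator.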
X = np.zeros((32, 634, 513, 1), dtype=np.float32)
y = np.zeros((32,), dtype=np.uint32)
model = Sequential(name='Test_model')
model.add(InputLayer((634, 513, 1), name='Input'))
model.add(Flatten(name='Flatten'))
model.add(Dense(1, activation='sigmoid', name='Output'))
model.summary()
model.compile(optimizer='adam', loss='binary_crossentropy',
              metrics=['accuracy'])
model.fit(
    X, y,
    epochs=1,
    verbose=1,
    #validation_data=valid_datagen
)
Model: "Test_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
Flatten (Flatten)            (None, 325242)            0
_________________________________________________________________
Output (Dense)               (None, 1)                 325243
=================================================================
Total params: 325,243
Trainable params: 325,243
Non-trainable params: 0
_________________________________________________________________
1/1 [==============================] - 14s 14s/step - loss: 0.6931 - accuracy: 1.0000
<keras.callbacks.History at 0x7f74adf34510>
TensorFlow version: 2.5.0
Keras version: 2.4.3
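One thing worth checking: the traceback points into the standalone `keras` package (`dist-packages/keras/engine/data_adapter.py`), while the generator subclasses `tensorflow.keras.utils.Sequence`. If the two packages ship separate copies of `Sequence`, an `isinstance` check in one will not recognise a subclass of the other, even though the batches themselves are fine. The following toy illustration of that assumption uses hypothetical stand-in classes and a simplified `pick_adapter` function; it is not the Keras code.

```python
class TfKerasSequence:
    """Stand-in for tensorflow.keras.utils.Sequence (hypothetical)."""


class StandaloneKerasSequence:
    """Stand-in for the Sequence class of a separately installed keras package (hypothetical)."""


class MyDataGenerator(TfKerasSequence):
    """Generator that subclasses the first base class, as in the question."""

    def __len__(self):
        return 1

    def __getitem__(self, index):
        return [0.0], [0]


def pick_adapter(x, expected_base):
    """Simplified selection: accept x only if it is an instance of expected_base."""
    if isinstance(x, expected_base):
        return "sequence adapter"
    raise ValueError(
        "Failed to find data adapter that can handle input: {}".format(type(x)))


gen = MyDataGenerator()
print(pick_adapter(gen, TfKerasSequence))  # accepted: same base class

try:
    pick_adapter(gen, StandaloneKerasSequence)  # rejected: different base class
except ValueError as err:
    print("ValueError:", err)
```

In the real environment, comparing `Sequence.__module__` with `Sequential.__module__` would show whether both classes come from the same package.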
Solution
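A likely fix, assuming the error comes from mixing `keras` and `tensorflow.keras` imports: take `Sequential`, the layers, and `Sequence` from the same package. The sketch below uses `tensorflow.keras` throughout. `ZeroBatchGenerator` is a hypothetical, scaled-down stand-in for `MyDataGenerator` (no DataFrame, a small `dim`, float32 labels) so that it runs quickly.

```python
import math

import numpy as np
from tensorflow.keras import Sequential
from tensorflow.keras.layers import InputLayer, Flatten, Dense
from tensorflow.keras.utils import Sequence


class ZeroBatchGenerator(Sequence):
    """Hypothetical stand-in for MyDataGenerator that yields all-zero batches."""

    def __init__(self, n_samples, batch_size=32, dim=(8, 8, 1)):
        super().__init__()
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.dim = dim

    def __len__(self):
        # Number of batches per epoch, rounding up for a partial last batch.
        return math.ceil(self.n_samples / self.batch_size)

    def __getitem__(self, index):
        start = index * self.batch_size
        stop = min(start + self.batch_size, self.n_samples)
        # Arrays are sized to the actual batch, so the last one may be shorter.
        X = np.zeros((stop - start, *self.dim), dtype=np.float32)
        y = np.zeros((stop - start,), dtype=np.float32)
        return X, y


train_datagen = ZeroBatchGenerator(n_samples=70)
valid_datagen = ZeroBatchGenerator(n_samples=40)

# Same model as in the question, with a smaller input shape.
model = Sequential(name='Test_model')
model.add(InputLayer((8, 8, 1), name='Input'))
model.add(Flatten(name='Flatten'))
model.add(Dense(1, activation='sigmoid', name='Output'))
model.compile(optimizer='adam', loss='binary_crossentropy',
              metrics=['accuracy'])

history = model.fit(train_datagen, epochs=1, verbose=0,
                    validation_data=valid_datagen)
print(sorted(history.history.keys()))
```

Importing everything from `keras` instead (including `from keras.utils import Sequence`) should work equally well; the point is not to mix the two packages.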