python - 在 keras 训练期间使用 fit_generator 的问题
问题描述
我正在处理非常大的文本数据集。我考虑过使用model.fit_generator
method 而不是 simple model.fit
,所以我尝试使用这个生成器:
def TrainGenerator(inp, out):
for i,o in zip(inp, out):
yield i,o
当我在训练期间尝试使用它时,使用:
#inp_train, out_train are lists of sequences padded to 50 tokens
model.fit_generator(generator = TrainGenerator(inp_train,out_train),
steps_per_epoch = BATCH_SIZE * 100,
epochs = 20,
use_multiprocessing = True)
我得到:
ValueError: Error when checking input: expected embedding_input to have shape (50,) but got array with shape (1,)
现在,我尝试使用简单的model.fit
方法,效果很好。所以,我认为我的问题出在生成器上,但是由于我是使用生成器的新手,我不知道如何解决它。完整的模型摘要是:
Layer (type) Output Shape
===========================================
Embedding (Embedding) (None, 50, 400)
___________________________________________
Bi_LSTM_1 (Bidirectional) (None, 50, 1024)
___________________________________________
Bi_LSTM_2 (Bidirectional) (None, 50, 1024)
___________________________________________
Output (Dense) (None, 50, 153)
===========================================
#编辑 1
第一条评论以某种方式触发了我。我意识到我误解了发电机的工作原理。我的生成器的输出是一个形状为 50 的列表,而不是一个形状为 50 的 N 个列表的列表。所以我深入研究了 keras 文档并找到了这个。所以,我改变了我的工作方式,这是作为生成器工作的类:
class BatchGenerator(tf.keras.utils.Sequence):
def __init__(self, x_set, y_set, batch_size):
self.x, self.y = x_set, y_set
self.batch_size = batch_size
def __len__(self):
return int(np.ceil(len(self.x) / float(self.batch_size)))
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
return batch_x, to_categorical(batch_y,num_labels)
to_categorical
函数在哪里:
def to_categorical(sequences, categories):
cat_sequences = []
for s in sequences:
cats = []
for item in s:
cats.append(np.zeros(categories))
cats[-1][item] = 1.0
cat_sequences.append(cats)
return np.array(cat_sequences)
所以,我现在注意到的是我的网络的良好性能提升,每个时期现在持续一半。发生这种情况是因为我有更多的可用 RAM,因为现在我没有将所有数据集加载到内存中吗?
解决方案
推荐阅读
- matrix - 如何用 CUDA 计算大矩阵的二维 FFT?
- scala - 如何延迟创建一个虚拟的akka Httpresponse?
- paypal - PayPal 快速结帐添加税务信息
- python - 数据输入期间的 Azure ML 时间序列模型推理错误(python)
- haskell - emacs haskell-mode 与 cabal 项目。“`--ghc-option=ferror-spans` 的目标语法无法识别。”
- ios - 如何根据带有日期的对象数组创建多个部分?
- python - 如何在logcat上查看python应用程序日志
- c# - Windows 服务给出错误 System.IO.DirectoryNotFoundException
- javascript - 以与 html 代码中相同格式的文本打印 Javascript 文本
- java - 如何从文本字段中删除焦点并将其设置在按钮上?