python - 如何在没有迭代的情况下在 Keras 中训练多输出模型时加载数据?
问题描述
我在 TensorFlow 2 中有一个带有 1 个输入和 2 个输出的 Keras 模型。调用时model.fit
我想传递数据集x=train_dataset
并调用model.fit
一次。train_dataset
是用tf.data.Dataset.from_generator
它产生的:x1,y1,y2。
我可以进行培训的唯一方法是:
for x1, y1,y2 in train_dataset:
model.fit(x=x1, y=[y1,y2],...)
如何告诉 TensorFlow 解包变量并在没有显式for
循环的情况下进行训练?使用for
循环使许多事情变得不那么实用,以及train_on_batch
.
如果我想运行model.fit(train_dataset, ...)
该函数不了解什么x
和y
是,即使模型定义如下:
model = Model(name ='Joined_Model',inputs=self.x, outputs=[self.network.y1, self.network.y2])
它会抛出一个错误,即在获得 1 时期望 2 个目标,即使数据集有 3 个变量,也可以在循环中迭代。
数据集和小批量生成为:
def dataset_joined(self, n_epochs, buffer_size=32):
dataset = tf.data.Dataset.from_generator(
self.mbatch_gen_joined,
(tf.float32, tf.float32,tf.int32),
(tf.TensorShape([None, None, self.n_feat]),
tf.TensorShape([None, None, self.n_feat]),
tf.TensorShape([None, None])),
[tf.constant(n_epochs)]
)
dataset = dataset.prefetch(buffer_size)
return dataset
def mbatch_gen_joined(self, n_epochs):
for _ in range(n_epochs):
random.shuffle(self.train_s_list)
start_idx, end_idx = 0, self.mbatch_size
for _ in range(self.n_iter):
s_mbatch_list = self.train_s_list[start_idx:end_idx]
d_mbatch_list = random.sample(self.train_d_list, end_idx-start_idx)
s_mbatch, d_mbatch, s_mbatch_len, d_mbatch_len, snr_mbatch, label_mbatch, _ = \
self.wav_batch(s_mbatch_list, d_mbatch_list)
x_STMS_mbatch, xi_bar_mbatch, _ = \
self.training_example(s_mbatch, d_mbatch, s_mbatch_len,
d_mbatch_len, snr_mbatch)
#seq_mask_mbatch = tf.cast(tf.sequence_mask(n_frames_mbatch), tf.float32)
start_idx += self.mbatch_size; end_idx += self.mbatch_size
if end_idx > self.n_examples: end_idx = self.n_examples
yield x_STMS_mbatch, xi_bar_mbatch, label_mbatch
解决方案
Keras 模型期望 Python 生成器或对象以(or )tf.data.Dataset
格式的元组形式提供输入数据。如果模型有多个输入/输出层,每个或应该是一个列表/元组。因此,在您的代码中,生成的数据也应该与这种预期格式兼容:(input_data, target_data)
(input_data, target_data, sample_weights)
input_data
target_data
yield x_STMS_mbatch, (xi_bar_mbatch, label_mbatch) # <- the second element is a tuple itself
此外,这也应该在传递给from_generator
方法的参数中考虑:
dataset = tf.data.Dataset.from_generator(
self.mbatch_gen_joined,
output_types=(
tf.float32,
(tf.float32, tf.int32)
),
output_shapes=(
tf.TensorShape([None, None, self.n_feat]),
(
tf.TensorShape([None, None, self.n_feat]),
tf.TensorShape([None, None])
)
),
args=(tf.constant(n_epochs),)
)
推荐阅读
- javascript - 使用 lodash 压缩两个对象,其中一个对象的键和另一个对象的值
- openapi - 将 mTLS Cloudflare 添加到 OpenAPI 3.0.0 yaml 文件
- r - 创建每月虚拟变量
- javascript - Dompdf 无法正确排列 html 元素
- javascript - 过滤器中的输入范围在移动设备上不起作用
- r - 如何找到 LISA.R 以在 R 中绘制 LISA 地图?
- version-control - 在 Perforce 中使用类型映射编辑更新所有文件
- javascript - Firebase & 网站关系信息争夺
- html - 在背景图像 css 上覆盖身体
- node.js - 您如何在 TypeORM 实体中定义多态关系?