python - 如何在 TensorFlow 2 中迭代多个数据集
问题描述
我使用TensorFlow 2.2.0。在我的数据管道中,我使用多个数据集来训练神经网络。就像是:
# these are all tf.data.Dataset objects:
paired_data = get_dataset(id=0, repeat=False, shuffle=True)
unpaired_images = get_dataset(id=1, repeat=True, shuffle=True)
unpaired_masks = get_dataset(id=2, repeat=True, shuffle=True)
在训练循环中,我想迭代paired_data
定义一个时期。但我也想迭代unpaired_images
并unpaired_masks
优化其他目标(用于语义分割的经典半监督学习,带有掩码鉴别器)。
为了做到这一点,我当前的代码如下所示:
def train_one_epoch(self, writer, step, paired_data, unpaired_images, unpaired_masks):
unpaired_images = unpaired_images.as_numpy_iterator()
unpaired_masks = unpaired_masks.as_numpy_iterator()
for images, labels in paired_data:
with tf.GradientTape() as sup_tape, \
tf.GradientTape() as gen_tape, \
tf.GradientTape() as disc_tape:
# paired data (supervised cost):
predictions = segmentor(images, training=True)
sup_loss = weighted_cross_entropy(predictions, labels)
# unpaired data (adversarial cost):
pred_real = discriminator(next(unpaired_masks), training=True)
pred_fake = discriminator(segmentor(next(unpaired_images), training=True), training=True)
gen_loss = generator_loss(pred_fake)
disc_loss = discriminator_loss(pred_real, pred_fake)
gradients = sup_tape.gradient(sup_loss, self.segmentor.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients, self.segmentor.trainable_variables))
gradients = gen_tape.gradient(gen_loss, self.segmentor.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients, self.segmentor.trainable_variables))
gradients = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
discriminator_optimizer.apply_gradients(zip(gradients, self.discriminator.trainable_variables))
但是,这会导致错误:
main.py:275 train_one_epoch *
unpaired_images = unpaired_images.as_numpy_iterator()
/home/venvs/conda/miniconda3/envs/tf-gpu/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py:476 as_numpy_iterator **
raise RuntimeError("as_numpy_iterator() is not supported while tracing "
RuntimeError: as_numpy_iterator() is not supported while tracing functions
知道这有什么问题吗?这是在 tensorflow 2 中优化多个损失/数据集的正确方法吗?
我在评论中添加了我当前对问题的解决方案。任何关于更优化方式的建议都非常受欢迎!:)
解决方案
我目前的解决方案:
def train_one_epoch(self, writer, step, paired_data, unpaired_images, unpaired_masks):
# create a new dataset zipping the three original dataset objects
dataset = tf.data.Dataset.zip((paired_data, unpaired_images, unpaired_masks))
dataset = dataset.batch(1)
for (images, labels), unpaired_images, unpaired_masks in dataset:
# access the elements as the first and only element of the batched dataset
images, labels, unpaired_images, unpaired_masks = \
images[0], labels[0], unpaired_images[0], unpaired_masks[0]
# go ahead and train:
with tf.GradientTape() as tape:
#[...]
推荐阅读
- java - GET 映射有效,但 POST、PUT 和 DELETE 的状态为 405(在 spring-boot Restful api 中)
- python - 如何将一列添加到具有不同值的多个 .csv 文件
- bootstrap-4 - 有没有boostrap会影响对齐?
- sharepoint-online - 如何在线获取 SharePoint 的所有主题并设置特定主题
- ruby - 如何使用 capybara 查看本地存储和会话存储
- c# - 线程池 - 创建了 10 个线程?
- angular - 无法从 EJ2 多选下拉列表中获取所选记录
- c# - 在共享主机上发布使用 .net core 2.2 构建的网站
- android - Gradle 查看错误的 Maven 存储库
- vba - 类型不匹配尝试在集合中的对象中设置数据