python - [Tensorflow 2]如何为具有不一致形状的数据的多输入多输出模型构建数据输入管道
问题描述
我正在使用Tensorflow 2,我需要构建一个多输入多输出模型,我的数据是时间序列数据,它的时间维度没有一致的形状。我尝试了很多方法,但由于形状不一致,都没有奏效。
共有三个数据,其中一个被使用了两次。它们的格式为(number of files, None, 5)
,None
维度为不一致维度。
这是一些重现我的问题的测试代码,在这种情况下我使用的是生成器,但可以随意更改为任何方法。有人可以帮我处理这个输入管道吗?
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
dummy_1 = [[[1.1,2,3,4,5],[2,3,4,5,6],[3,4,5,6,7]],
[[1.2,2,3,4,5],[2,3,4,5,6.8]],
[[1.3,2,3,4,5],[2,3,4,5,6],[3,4,5,6,7],[4,5,6,7,8.9]]]
dummy_2 = [[[1.1,2,3,4,5],[2,3,4,5,6]],
[[1.1,2,3,4,5],[2,3,4,5,6]],[3,4,5,6,7],
[[1.3,2,3,4,5],[2,3,4,5,6]]]
dummy_3 = [[[1.5,2,3,4,5],[2,3,4,5,6]],
[[1.6,2,3,4,5],[2,3,4,5,6]],[3,4,5,6,7],
[[1.7,2,3,4,5],[2,3,4,5,6]]]
def gen():
for i in range(len(dummy_1)):
yield(dummy_1[i],dummy_2[i],dummy_2[i],dummy_3[i])
def custom_loss(y_true, y_pred):
return tf.reduce_mean(tf.abs(y_pred - y_true))
class network():
def __init__(self):
input_1 = keras.Input(shape=(None,5))
input_2 = keras.Input(shape=(None,5))
output_1 = layers.Conv1DTranspose(16, 3, padding='same', activation='relu')(input_1)
output_2 = layers.Conv1DTranspose(16, 3, padding='same', activation='relu')(input_2)
self.model = keras.Model(inputs=[input_1, input_2],
outputs=[output_1, output_2])
# compile model
self.model.compile(optimizer=keras.optimizers.SGD(learning_rate=0.001),
loss={"mel_loss":custom_loss, "mag_loss":custom_loss})
def train(self):
self.dataset = tf.data.Dataset.from_generator(gen,
(tf.float32, tf.float32, tf.float32, tf.float32))
self.dataset.batch(32).repeat()
self.model.fit(self.dataset,epochs=3)
#self.model.fit([dummy_1, dummy_2],
# [dummy_2, dummy_3],
# epochs=3)
net = network()
net.train()
解决方案
这对于 TF2 目前是不可能的,参考https://github.com/tensorflow/tensorflow/issues/45112
推荐阅读
- python - Raspberry Pi 相机:在安装的 USB 记忆棒上保存时出现 PermissionError
- django - 在elpy(Django命令)中从默认python解释器切换到ipython
- tfs - 无法访问工作项跟踪服务 Azure DevOps Extensions
- javascript - Angular. Get server response in catchError
- swift - Swift 中带有闭包的模棱两可的方法重载,但仅当闭包返回一个值时
- filter - ffmpeg 剪切视频 + 重新缩放预览(代理)
- groovy - (Groovy) 查找所有包含特定用户组的 Confluence 空间
- hashicorp-vault - Vault 数据库机密引擎忽略非默认端口
- powershell - 在 foreach 中使用演员表
- javascript - 在 Wkwebview Cordova ios 14 中打开 Websql 数据库