首页 > 解决方案 > [Tensorflow 2]如何为具有不一致形状的数据的多输入多输出模型构建数据输入管道

问题描述

我正在使用Tensorflow 2,我需要构建一个多输入多输出模型,我的数据是时间序列数据,它的时间维度没有一致的形状。我尝试了很多方法,但由于形状不一致,都没有奏效。

共有三个数据,其中一个被使用了两次。它们的格式为(number of files, None, 5),None维度为不一致维度。

这是一些重现我的问题的测试代码,在这种情况下我使用的是生成器,但可以随意更改为任何方法。有人可以帮我处理这个输入管道吗?

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

dummy_1 = [[[1.1,2,3,4,5],[2,3,4,5,6],[3,4,5,6,7]],
           [[1.2,2,3,4,5],[2,3,4,5,6.8]],
           [[1.3,2,3,4,5],[2,3,4,5,6],[3,4,5,6,7],[4,5,6,7,8.9]]]

dummy_2 = [[[1.1,2,3,4,5],[2,3,4,5,6]],
           [[1.1,2,3,4,5],[2,3,4,5,6]],[3,4,5,6,7],
           [[1.3,2,3,4,5],[2,3,4,5,6]]]

dummy_3 = [[[1.5,2,3,4,5],[2,3,4,5,6]],
           [[1.6,2,3,4,5],[2,3,4,5,6]],[3,4,5,6,7],
           [[1.7,2,3,4,5],[2,3,4,5,6]]]

def gen():
    for i in range(len(dummy_1)):
        yield(dummy_1[i],dummy_2[i],dummy_2[i],dummy_3[i])

def custom_loss(y_true, y_pred):
    return tf.reduce_mean(tf.abs(y_pred - y_true))

class network():
    def __init__(self):
        input_1 = keras.Input(shape=(None,5))
        input_2 = keras.Input(shape=(None,5))
        output_1 = layers.Conv1DTranspose(16, 3, padding='same', activation='relu')(input_1)
        output_2 = layers.Conv1DTranspose(16, 3, padding='same', activation='relu')(input_2)
        
        self.model = keras.Model(inputs=[input_1, input_2],
                                 outputs=[output_1, output_2])
        
        # compile model
        self.model.compile(optimizer=keras.optimizers.SGD(learning_rate=0.001),
                           loss={"mel_loss":custom_loss, "mag_loss":custom_loss})
    
    def train(self):
        self.dataset = tf.data.Dataset.from_generator(gen, 
                                                      (tf.float32, tf.float32, tf.float32, tf.float32))
        self.dataset.batch(32).repeat()
        
        self.model.fit(self.dataset,epochs=3)
        #self.model.fit([dummy_1, dummy_2],
        #               [dummy_2, dummy_3],
        #               epochs=3)

net = network()
net.train()

标签: pythontensorflowkerastensorflow2.0

解决方案


这对于 TF2 目前是不可能的,参考https://github.com/tensorflow/tensorflow/issues/45112


推荐阅读