首页 > 解决方案 > 如何使用 keras 将张量提供给计算图中的预训练模型?

问题描述

我想GAN在生成器的末尾用一些确定性约束训练一个特定的条件,为此Keras我需要首先计算生成器输出的嵌入VGG-16 pre-trained model

我正在使用python 3.6.

在我的计算图中,我想将我的生成器输出img提供给预训练的 VGG-16 模型,以便获得嵌入。

img因为我在计算图中,所以我是一个形状张量 (None,224,224,3) 。问题是,如果我编译以下内容,我会收到错误

当向模型提供符号张量时,我们希望张量具有静态批量大小。得到具有形状的张量:(None, 224, 224, 3)

self.vgg = self.build_vgg()

def build_vgg(self):
    vgg16_model = keras.applications.vgg16.VGG16()
    return Model(inputs=vgg16_model.input,outputs=vgg16_model.get_layer('fc2').output)

        #-------------------------------
    # Construct Computational Graph
    #         for Generator
    #-------------------------------

    # For the generator we freeze the critic's layers
    self.critic.trainable = False
    self.generator.trainable = True
    self.vgg.trainable = False


    # Sampled noise for input to generator
    noise = Input(shape=(self.latent_dim,))

    # Input Embedding:
    embedding = Input(shape=(self.embedding,))


    # Generate images based of noise

    img = self.generator([noise,embedding])

    # Discriminator determines validity

    valid = self.critic(img)

    # Get the embeddings from vgg-16:
    X = self.vgg.predict(img)

显然,我不能沿着第一个轴循环,因为它是无索引。我尝试使用 tensorflow 函数'tf.map_fn'将函数应用于此 'img' 张量,如下所示:

    def Embedding(self,img):
    fn = lambda x: self.vgg.predict(preprocess_input(np.expand_dims(x, axis=0))).flatten()
    embedding = tf.map_fn(fn,img,dtype=tf.float32)
    return embedding

        #-------------------------------
    # Construct Computational Graph
    #         for Generator
    #-------------------------------

    # For the generator we freeze the critic's layers
    self.critic.trainable = False
    self.generator.trainable = True
    self.vgg.trainable = False

    # Sampled noise for input to generator
    noise = Input(shape=(self.latent_dim,))

    # Input Embedding:
    embedding = Input(shape=(self.embedding,))


    # Generate images based of noise

    img = self.generator([noise,embedding])

    # Discriminator determines validity

    valid = self.critic(img)

    # Get the embeddings from VGG16
    X = self.Embedding(img)

但我收到以下错误:

ValueError:使用序列设置数组元素。

回顾一下,我pre-trained VGG-16 modeltensorKeras. 我之前向你解释的是我已经尝试过的......

有人对此有什么建议吗?

标签: pythontensorflowkeras

解决方案


推荐阅读