python - 如何使用 keras 将张量提供给计算图中的预训练模型?
问题描述
我想GAN
在生成器的末尾用一些确定性约束训练一个特定的条件,为此Keras
我需要首先计算生成器输出的嵌入VGG-16 pre-trained model
。
我正在使用python 3.6
.
在我的计算图中,我想将我的生成器输出img
提供给预训练的 VGG-16 模型,以便获得嵌入。
img
因为我在计算图中,所以我是一个形状张量 (None,224,224,3) 。问题是,如果我编译以下内容,我会收到错误
当向模型提供符号张量时,我们希望张量具有静态批量大小。得到具有形状的张量:(None, 224, 224, 3)
self.vgg = self.build_vgg()
def build_vgg(self):
vgg16_model = keras.applications.vgg16.VGG16()
return Model(inputs=vgg16_model.input,outputs=vgg16_model.get_layer('fc2').output)
#-------------------------------
# Construct Computational Graph
# for Generator
#-------------------------------
# For the generator we freeze the critic's layers
self.critic.trainable = False
self.generator.trainable = True
self.vgg.trainable = False
# Sampled noise for input to generator
noise = Input(shape=(self.latent_dim,))
# Input Embedding:
embedding = Input(shape=(self.embedding,))
# Generate images based of noise
img = self.generator([noise,embedding])
# Discriminator determines validity
valid = self.critic(img)
# Get the embeddings from vgg-16:
X = self.vgg.predict(img)
显然,我不能沿着第一个轴循环,因为它是无索引。我尝试使用 tensorflow 函数'tf.map_fn'将函数应用于此 'img' 张量,如下所示:
def Embedding(self,img):
fn = lambda x: self.vgg.predict(preprocess_input(np.expand_dims(x, axis=0))).flatten()
embedding = tf.map_fn(fn,img,dtype=tf.float32)
return embedding
#-------------------------------
# Construct Computational Graph
# for Generator
#-------------------------------
# For the generator we freeze the critic's layers
self.critic.trainable = False
self.generator.trainable = True
self.vgg.trainable = False
# Sampled noise for input to generator
noise = Input(shape=(self.latent_dim,))
# Input Embedding:
embedding = Input(shape=(self.embedding,))
# Generate images based of noise
img = self.generator([noise,embedding])
# Discriminator determines validity
valid = self.critic(img)
# Get the embeddings from VGG16
X = self.Embedding(img)
但我收到以下错误:
ValueError:使用序列设置数组元素。
回顾一下,我pre-trained VGG-16 model
想tensor
在Keras
. 我之前向你解释的是我已经尝试过的......
有人对此有什么建议吗?
解决方案
推荐阅读
- debugging - 如何跟随分叉,但在 gdb 中的 exec 上分离
- javascript - 我可以将我现有的 vue 组件用作单个文件组件吗
- python - 如何为 python 聊天程序提供 SSL 加密?
- sql - PostgreSQL:结合这两个查询
- rust - 如何通过 gRPC 函数使用 Hyper 客户端?
- swift - 使用childByAutoId方法时如何防止firebase数据库创建多个ID
- ruby - 有没有办法强制捆绑器接受特定版本的 gem?
- javascript - 显示/隐藏文本输入边框的 Javascript
- r - 为惩罚模型写一个循环
- python - 更改 Scrapy 下载图像名称