python - 如何在 Keras 中实现包括 GAN 生成器的自定义损失函数?
问题描述
我想为基于 Encoder-Generator 训练的真实输入图像使用 PGGAN 生成器找到类似的图像。下面是我的实现:
# load pre-trained generator
sess = tf.InteractiveSession()
with open('network-snapshot-final.pkl', 'rb') as file:
G, D, Gs = pickle.load(file)
# network parameters
image_size = 1024
input_shape = (image_size, image_size, 1)
batch_size = 8
kernel_size = 3
filters = 16
latent_dim = 512
epochs = 100
# build an encoder
inputs = Input(shape=input_shape, name='encoder_input')
x = inputs
for i in range(10):
filters *= 2
x = Conv2D(filters=filters,
kernel_size=kernel_size,
activation='relu',
strides=2,
padding='same')(x)
# generate latent vector
x = Flatten()(x)
x = Dense(2048, activation='relu')(x)
z_sim = Dense(latent_dim, name='z_sim')(x)
encoder = Model(inputs, z_sim, name='encoder')
# define a custom loss function
def loss_enc(x, z_sim):
im_g = tf.convert_to_tensor(Gs.run(z_sim.eval(), labels))
im_g2 = tf.reshape(im_g, [-1, 1024, 1024, 1])
los = mse(K.flatten(x), K.flatten(im_g2))
return los
编译模型后,遇到如下错误信息:
encoder.compile(optimizer='rmsprop', loss=loss_enc)
InvalidArgumentError:您必须为占位符张量“encoder_input_19”提供一个值,其 dtype 浮点数和形状 [?,1024,1024,1] [[{{node encoder_input_19}} = Placeholderdtype=DT_FLOAT, shape= [?,1024,1024,1 ], _device="/job:localhost/replica:0/task:0/device:GPU:0"]] [[{{node z_sim_12/BiasAdd/_713}} = _Recvclient_terminated=false, recv_device="/job:localhost /replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_127_z_sim_12/BiasAdd" , tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]]
为此,我该如何正确实现损失函数?
解决方案
首先:
def loss_enc(x, z_sim):
def loss(y_pred, y_true):
# Things you would do with x, z_sim and store in 'result' (for example)
return result
return loss
编译模型时:
encoder.compile(optimizer='rmsprop', loss=loss_enc(x, z_sim))
推荐阅读
- c# - 需要从 URL 读取流,更改 URI 中的“页面”属性,直到它返回“未找到记录”字符串
- model-view-controller - Redirect in MVC controller to another controller not working
- visual-studio-code - 如何在 VS Code 中运行 Lua 脚本
- ruby-on-rails - Edgeguides 的 Rails 入门模板缺失错误
- ruby - 如何在多个配置文件中包含自定义检查资源?
- python - 雪碧没有出现在pygame窗口上
- python - 如何在使用 Python 和 Redis 以管理员身份登录时创建注销用户的函数
- javascript - 将 jQuery 转换为 JavaScript 以进行 Google Analytics pdf 下载跟踪
- asp.net-core - 使用 EntityFrameworkCore 向 ASP.NET Core 项目添加标识时出现问题
- python - 如何在 Django 中使用 Hardcopy 在所选目录中创建 pdf 输出文件?