keras - Dimension error in the Keras fit function (conditional variational autoencoder in Keras)
Problem Description
I am trying to implement a conditional variational autoencoder. It is fairly simple, but I get an error when calling the fit function. Here is the complete code snippet:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

class cVAE(keras.Model):
    def __init__(self,
                 original_dim,
                 label_dim,
                 latent_dim,
                 beta=1,
                 batch_size=1,
                 **kwargs):
        super(cVAE, self).__init__(**kwargs)
        self.original_dim = original_dim
        self.latent_dim = latent_dim
        self.label_dim = label_dim
        self.beta = beta
        self.batch_size = batch_size

        # Build the encoder
        print("building encoder")
        rnaseq_inputs = keras.Input(shape=(self.original_dim, ), batch_size=self.batch_size)
        label_inputs = keras.Input(shape=(self.label_dim, ), batch_size=self.batch_size)
        encoder_inputs = layers.concatenate([rnaseq_inputs, label_inputs], name='concat_1')
        z_mean = layers.Dense(self.latent_dim,
                              kernel_initializer='glorot_uniform')(encoder_inputs)
        z_mean = layers.BatchNormalization()(z_mean)
        z_mean = layers.Activation('relu')(z_mean)
        z_log_var = layers.Dense(self.latent_dim,
                                 kernel_initializer='glorot_uniform')(encoder_inputs)
        z_log_var = layers.BatchNormalization()(z_log_var)
        z_log_var = layers.Activation('relu')(z_log_var)
        z = Sampling()([z_mean, z_log_var])
        zc = layers.concatenate([z, label_inputs], name='concat_2')
        self.encoder = keras.Model([rnaseq_inputs, label_inputs], [z_mean, z_log_var, z, zc])

        print("building decoder")
        # Build the decoder
        decoder_input_dim = self.latent_dim + self.label_dim
        decoder_output_dim = self.original_dim + self.label_dim
        decoder_inputs = keras.Input(shape=(decoder_input_dim, ))
        decoder_outputs = keras.layers.Dense(decoder_output_dim,
                                             activation='sigmoid')(decoder_inputs)
        self.decoder = keras.Model(decoder_inputs, decoder_outputs)

        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            # exp_data, label_data = data
            z_mean, z_log_var, z, zc = self.encoder(data)
            reconstruction = self.decoder(zc)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.mean_squared_error(data, reconstruction)
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

toy_data = np.random.random((100, 100)).astype('float32')
label = np.random.randint(0, high=2, size=100).reshape(100, 1).astype('float32')
cvae_model = cVAE(original_dim=100, batch_size=2, label_dim=1, latent_dim=1)
cvae_model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.003))
# fitting
cvae_model.fit([toy_data, label])
Everything works fine up to the fit call. To my surprise, fit throws the following error:
ValueError: in user code:
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:805 train_function *
return step_function(self, iterator)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:795 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
/usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:1259 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:2730 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:3417 _call_for_each_replica
return fn(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py:788 run_step **
outputs = model.train_step(data)
<ipython-input-232-1cc639e2055c>:182 train_step
keras.losses.mean_squared_error(data, reconstruction)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper
return target(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/losses.py:1197 mean_squared_error
y_true = math_ops.cast(y_true, y_pred.dtype)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper
return target(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/math_ops.py:964 cast
x = ops.convert_to_tensor(x, name="x")
/usr/local/lib/python3.7/dist-packages/tensorflow/python/profiler/trace.py:163 wrapped
return func(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:1540 convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/array_ops.py:1525 _autopacking_conversion_function
return _autopacking_helper(v, dtype, name or "packed")
/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/array_ops.py:1444 _autopacking_helper
converted_elem = _autopacking_helper(elem, dtype, str(i))
/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/array_ops.py:1461 _autopacking_helper
return gen_array_ops.pack(elems_as_tensors, name=scope)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/gen_array_ops.py:6398 pack
"Pack", values=values, axis=axis, name=name)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/op_def_library.py:750 _apply_op_helper
attrs=attr_protos, op_def=op_def)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py:592 _create_op_internal
compute_device)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:3536 _create_op_internal
op_def=op_def)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:2016 __init__
control_input_ops, op_def)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:1856 _create_c_op
raise ValueError(str(e))
ValueError: Dimension 1 in both shapes must be equal, but are 100 and 1. Shapes are [2,100] and [2,1].
From merging shape 0 with other shapes. for '{{node Cast/x/0}} = Pack[N=2, T=DT_FLOAT, axis=0](IteratorGetNext, IteratorGetNext:1)' with input shapes: [2,100], [2,1].
I don't understand why it cannot merge [2,100] and [2,1] along axis 1; that should produce [2,101]. Am I missing something?
Here is what plot_model produces for the encoder (image not shown).
PS: I tried different axis values for the concatenation, but none of them worked.
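The failing node in the traceback is a Pack op: TensorFlow is trying to stack the two input tensors along a new axis (which requires identical shapes), not concatenate them along axis 1. A small NumPy sketch of the difference, using the shapes from the error message (array names are illustrative):

```python
import numpy as np

x = np.zeros((2, 100), dtype='float32')  # batch of expression data
y = np.zeros((2, 1), dtype='float32')    # batch of labels

# Concatenating along axis 1 works and yields shape (2, 101):
cat = np.concatenate([x, y], axis=1)
print(cat.shape)  # (2, 101)

# Stacking (what the Pack op does) requires identical shapes, so it fails:
try:
    np.stack([x, y], axis=0)
except ValueError as e:
    print("stack failed:", e)
```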
Solution
The problem is in the reconstruction loss: `data` is a tuple of two tensors with shapes [2,100] and [2,1], and passing it directly to mean_squared_error makes TensorFlow try to pack the tuple into a single tensor, which fails. Concatenating the two inputs into one target tensor in the train step fixes it:
def train_step(self, data):
    with tf.GradientTape() as tape:
        # exp_data, label_data = data
        z_mean, z_log_var, z, zc = self.encoder(data)
        # form_data = np.concatenate(data)
        reconstruction = self.decoder(zc)
        data_cat = layers.concatenate([data[0][0], data[0][1]], axis=1)
        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.mean_squared_error(data_cat, reconstruction)
            )
        )
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + kl_loss
    grads = tape.gradient(total_loss, self.trainable_weights)
    self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
    self.total_loss_tracker.update_state(total_loss)
    self.reconstruction_loss_tracker.update_state(reconstruction_loss)
    self.kl_loss_tracker.update_state(kl_loss)
    return {
        "loss": self.total_loss_tracker.result(),
        "reconstruction_loss": self.reconstruction_loss_tracker.result(),
        "kl_loss": self.kl_loss_tracker.result(),
    }