python - Custom loss function for fidelity in TensorFlow - input shape problem
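Problem description
I am having trouble feeding a custom loss function that is part of a basic multi-objective neural network (GRU-based so far).
The model takes a flattened (,16) input, i.e. a 4x4 matrix (the "input matrix"), and generates a sequence of 10 4x4 matrices, represented as a flattened output of shape (,160). This is the first objective and uses the 'mse' loss. This part of the model, with `data.lab` as training data and `inputreshape` as label data, works fine.
As a second objective, I need to convert this (,160) tensor into a (10,4,4) tensor (reshaped as 10 4x4 matrices) and then combine the matrices into a single (4,4) "output matrix" (done by the productlayer custom layer). This output matrix then needs to be passed to the custom loss function 'fidelity2', which computes the "fidelity" by comparing it with the original input from `data.lab`. I do this by passing `data.lab` a second time, this time as a label. However, (1) model.summary() seems to show that the output of productlayer is (4,4), whereas I think it should be (None,4,4), and (2) I get the error: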
InvalidArgumentError: In[0] is not a matrix. Instead it has shape [10,4,4]
[[{{node loss_7/fidout_loss/ArithmeticOptimizer/FoldTransposeIntoMatMul_matmul}}]]
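The y_true in the custom loss function seems to receive the entire label set from model.fit (see the code below) rather than a batch, so the whole (10,4,4) label set is passed in and compared against the (4,4) y_pred. I am not sure how to fix this. The code is below. Ultimately I need to be able to compute the fidelity (in the end, actually the infidelity) in this way for arbitrary batch sizes; I set the batch size in model.fit to 1 only to try to get it working. Suggestions appreciated.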
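```python
import tensorflow as tf

def fidelity2(y_true, y_pred):
    # Without a perm argument, tf.transpose reverses ALL axes of y_true
    y_truetp = tf.transpose(y_true)
    t1 = (y_truetp @ y_pred)
    tr = tf.trace(t1)
    # Size of the first axis of y_pred, used as the matrix dimension
    mxdim = tf.cast(tf.shape(y_pred)[0], tf.float32)
    fidelity = (tf.abs((tr) ** 2) / mxdim ** 2)
    return fidelity
```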
```python
from tensorflow.keras import layers

# data, inputreshape and productlayer are defined elsewhere (not shown)
x = layers.Input(shape=(data.realdim, data.realdim), name='input1', batch_size=None)
x1 = layers.GRU(data.Uj_dim, return_sequences=True)(x)
x1 = layers.Dropout(rate=0.2)(x1)
x1 = layers.GRU(data.Uj_dim, return_sequences=True)(x1)
x1 = layers.Dropout(rate=0.2)(x1)
x1 = layers.GRU(data.Uj_dim, return_sequences=True)(x1)
x1 = layers.Flatten()(x1)
# First objective: flattened sequence of 10 4x4 matrices
y = layers.Dense(160, activation='relu', name='output_data')(x1)
xreshape = layers.Reshape((4, 4))(x)
# Second objective: custom layer that reduces the sequence to one 4x4 matrix
y2 = productlayer(trainable=True, name="fidout")(x1)
model = tf.keras.models.Model(inputs=x, outputs=[y, y2])
batchsz = 1

model.compile(optimizer='adam', loss=['mse', fidelity2], metrics=['mse', fidelity2])
model.summary()
model.fit(data.lab, [inputreshape, data.lab], epochs=2, batch_size=batchsz,
          validation_split=0, shuffle=False, steps_per_epoch=1)
```
```
Model: "model_13"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input1 (InputLayer)             [(None, 4, 4)]       0
__________________________________________________________________________________________________
gru_39 (GRU)                    (None, 4, 40)        5400        input1[0][0]
__________________________________________________________________________________________________
dropout_26 (Dropout)            (None, 4, 40)        0           gru_39[0][0]
__________________________________________________________________________________________________
gru_40 (GRU)                    (None, 4, 40)        9720        dropout_26[0][0]
__________________________________________________________________________________________________
dropout_27 (Dropout)            (None, 4, 40)        0           gru_40[0][0]
__________________________________________________________________________________________________
gru_41 (GRU)                    (None, 4, 40)        9720        dropout_27[0][0]
__________________________________________________________________________________________________
flatten_13 (Flatten)            (None, 160)          0           gru_41[0][0]
__________________________________________________________________________________________________
output_data (Dense)             (None, 160)          25760       flatten_13[0][0]
__________________________________________________________________________________________________
fidout (productlayer)           (4, 4)               0           flatten_13[0][0]
==================================================================================================
Total params: 50,600
Trainable params: 50,600
Non-trainable params: 0
__________________________________________________________________________________________________
Epoch 1/2
---------------------------------------------------------------------------
InvalidArgumentError...InvalidArgumentError: In[0] is not a matrix. Instead it has shape [10,4,4]
[[{{node loss_7/fidout_loss/ArithmeticOptimizer/FoldTransposeIntoMatMul_matmul}}]]
```
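On point (1): the (4, 4) shape reported for fidout suggests that the custom layer collapses the batch axis together with the sequence axis. The definition of productlayer is not shown, so the following is only a NumPy sketch of the shape bookkeeping such a layer would need in order to return (None, 4, 4). The name `chain_product`, its parameters, and the left-to-right multiplication order are assumptions made for illustration.

```python
import numpy as np

def chain_product(flat, seq_len=10, dim=4):
    """Multiply seq_len (dim x dim) matrices per sample, keeping the batch axis.

    flat: array of shape (batch, seq_len * dim * dim)
    returns: array of shape (batch, dim, dim)
    """
    # -1 leaves the batch size free: (batch, 160) -> (batch, 10, 4, 4)
    mats = flat.reshape(-1, seq_len, dim, dim)
    out = mats[:, 0]
    for j in range(1, seq_len):
        # np.matmul treats the leading axis as a batch axis
        out = np.matmul(out, mats[:, j])
    return out

# Three samples, each a sequence of ten flattened 4x4 identity matrices
batch = np.tile(np.eye(4).reshape(1, 16), (3, 10))
print(chain_product(batch).shape)  # (3, 4, 4)
```

The same reshape-then-matmul pattern should carry over to `tf.reshape` and `tf.linalg.matmul` inside a custom layer, since both accept a leading batch axis.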
Solution
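No answer was recorded for this question. As a starting point for point (2), here is a minimal sketch of a batch-aware version of the loss, assuming TensorFlow 2.x, real-valued matrices (as in the original `fidelity2`), and a model output that keeps the batch axis so that both y_true and y_pred have shape (batch, 4, 4). The names `batched_fidelity` and `infidelity_loss` are hypothetical. The main differences from `fidelity2` are that only the last two axes of y_true are transposed and that the matrix dimension is read from the last axis rather than the first.

```python
import tensorflow as tf

def batched_fidelity(y_true, y_pred):
    """Per-sample |Tr(Y_true^T Y_pred)|^2 / d^2 for inputs of shape (batch, d, d)."""
    y_true = tf.cast(y_true, y_pred.dtype)
    # transpose_a swaps only the last two axes, so the batch axis stays in front
    prod = tf.linalg.matmul(y_true, y_pred, transpose_a=True)
    tr = tf.linalg.trace(prod)  # shape (batch,)
    # Matrix dimension d from the LAST axis; axis 0 is the batch size
    d = tf.cast(tf.shape(y_pred)[-1], y_pred.dtype)
    return tf.abs(tr) ** 2 / d ** 2

def infidelity_loss(y_true, y_pred):
    # Minimising the infidelity pushes the fidelity towards 1
    return 1.0 - batched_fidelity(y_true, y_pred)
```

Under these assumptions it could be passed to compile in place of `fidelity2`, for example `loss=['mse', infidelity_loss]`, and should then work for any batch size.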