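python - Trainable matrix multiplication layer
Problem description
I am trying to build a (custom) trainable matrix-multiplication layer in TensorFlow, but it isn't working... More precisely, my model should look like this: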
x -> A(x) x
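where A(x) is a feed-forward network whose output is an n x n matrix (so the matrix depends on the input x), and A(x) x is the matrix-vector product of that matrix with x.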
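A minimal NumPy sketch of the intended map x -> A(x) x, assuming A(x) comes from the same two-layer network used in the layer code (the names W1, W2 and b are chosen here for illustration). The key point is that every sample in a batch gets its own n x n matrix, so A has shape (batch, n, n):

```python
import numpy as np


def forward(x, W1, b, W2, n):
    """Sketch of y = A(x) x for a batch of inputs.

    x:  (batch, n) inputs
    W1: (n, n), b: (n,)  first dense layer
    W2: (n, n * n)       second dense layer, reshaped to one n x n matrix per sample
    """
    h = np.maximum(x @ W1 + b, 0.0)        # ReLU hidden layer, shape (batch, n)
    A = (h @ W2).reshape(-1, n, n)         # shape (batch, n, n)
    return np.einsum("bij,bj->bi", A, x)   # y[b] = A[b] @ x[b], shape (batch, n)


rng = np.random.default_rng(0)
n = 2
x = rng.normal(size=(4, n))
y = forward(x, rng.normal(size=(n, n)), rng.normal(size=n), rng.normal(size=(n, n * n)), n)
print(y.shape)  # (4, 2)
```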
Here is the code I wrote:
import tensorflow as tf

class custom_layer(tf.keras.layers.Layer):
    def __init__(self, units=16, input_dim=32):
        super(custom_layer, self).__init__()
        self.units = units

    def build(self, input_shape):
        self.Tw1 = self.add_weight(name='Weights_1',
                                   shape=(input_shape[-1], input_shape[-1]),
                                   initializer='GlorotUniform',
                                   trainable=True)
        self.Tw2 = self.add_weight(name='Weights_2',
                                   shape=(input_shape[-1], (self.units)**2),
                                   initializer='GlorotUniform',
                                   trainable=True)
        self.Tb = self.add_weight(name='biases',
                                  shape=(input_shape[-1],),
                                  initializer='GlorotUniform',  # Previously 'ones'
                                  trainable=True)

    def call(self, input):
        # Build Vector-Valued Feed-Forward Network
        ffNN = tf.matmul(input, self.Tw1) + self.Tb
        ffNN = tf.nn.relu(ffNN)
        ffNN = tf.matmul(ffNN, self.Tw2)  # shape here: (batch_size, units**2)
        # Map to Matrix
        ffNN = tf.reshape(ffNN, [self.units, self.units])
        # Multiply Matrix-Valued function with input data
        x_out = tf.matmul(ffNN, input)
        # Return Output
        return x_out
Now I build the model:
input_layer = tf.keras.Input(shape=[2])
output_layer = custom_layer(2)(input_layer)
model = tf.keras.Model(inputs=[input_layer], outputs=[output_layer])
# Compile Model
#----------------#
# Define Optimizer
optimizer_on = tf.keras.optimizers.SGD(learning_rate=10**(-1))
# Compile
model.compile(loss = 'mse',
              optimizer = optimizer_on,
              metrics = ['mse'])
# Fit Model
#----------------#
model.fit(data_x, data_y, epochs=(10**1), verbose=0)
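data_x and data_y are not shown in the question. Since Input(shape=[2]) takes 2-vectors and A(x) x is again a 2-vector, both arrays should have shape (num_samples, 2). Hypothetical placeholder arrays of that shape (the sample count and the target map are arbitrary choices, not from the question) could look like this:

```python
import numpy as np

# Hypothetical placeholder data: the question does not show data_x / data_y.
rng = np.random.default_rng(0)
num_samples = 1000
data_x = rng.normal(size=(num_samples, 2)).astype("float32")

# Hypothetical target: a fixed 90-degree rotation M, i.e. a constant A(x) = M,
# just so the model has something learnable to fit.
M = np.array([[0.0, -1.0], [1.0, 0.0]], dtype="float32")
data_y = data_x @ M.T  # row-wise: data_y[i] = M @ data_x[i]

print(data_x.shape, data_y.shape)  # (1000, 2) (1000, 2)
```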
Then I get this error message:
InvalidArgumentError: Input to reshape is a tensor with 128 values, but the requested shape has 4
[[node model_62/reconfiguration_unit_70/Reshape (defined at <ipython-input-176-0b494fa3fc75>:46) ]] [Op:__inference_distributed_function_175181]
Errors may have originated from an input operation.
Input Source operations connected to node model_62/reconfiguration_unit_70/Reshape:
model_62/reconfiguration_unit_70/MatMul_1 (defined at <ipython-input-176-0b494fa3fc75>:41)
Function call stack:
distributed_function
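The numbers in this message can be reproduced with NumPy alone. Assuming model.fit ran with its default batch_size of 32 (the question does not pass one), the second matmul yields a (32, 4) tensor, i.e. 128 values, while tf.reshape(ffNN, [self.units, self.units]) with units=2 asks for only 2 x 2 = 4:

```python
import numpy as np

batch_size, units = 32, 2                   # model.fit default batch size; custom_layer(2)
ffNN = np.zeros((batch_size, units ** 2))   # stands in for tf.matmul(ffNN, self.Tw2)
print(ffNN.size)                            # 128 values, as in the error message

try:
    ffNN.reshape(units, units)              # what the layer requests: 4 values
except ValueError as err:
    print("reshape failed:", err)

A = ffNN.reshape(-1, units, units)          # keeping the batch axis works
print(A.shape)                              # (32, 2, 2)
```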
Thoughts: there seems to be a problem with the network's dimensions, but I don't know what it is or how to fix it...
Solution
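The reshape drops the batch dimension: Keras calls the layer on a whole batch of shape (batch_size, n), so A(x) must be a stack of batch_size matrices rather than a single n x n matrix, and tf.matmul(ffNN, input) must become a per-sample matrix-vector product. A possible fix, sketched under the assumption that units equals the input dimension (otherwise A(x) x is undefined), reshapes to [-1, units, units] and uses tf.linalg.matvec. The unused input_dim argument is dropped, and data_x / data_y are random placeholders because the question does not show them:

```python
import numpy as np
import tensorflow as tf


class custom_layer(tf.keras.layers.Layer):
    """y = A(x) x, where A(x) is an n x n matrix output by a small
    feed-forward network (n = input dimension, which must equal units)."""

    def __init__(self, units=16, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        n = int(input_shape[-1])
        if n != self.units:
            # A(x) x is only defined if A(x) is n x n.
            raise ValueError(f"units ({self.units}) must equal the input dimension ({n})")
        self.Tw1 = self.add_weight(name="weights_1", shape=(n, n),
                                   initializer="glorot_uniform", trainable=True)
        self.Tw2 = self.add_weight(name="weights_2", shape=(n, self.units ** 2),
                                   initializer="glorot_uniform", trainable=True)
        self.Tb = self.add_weight(name="biases", shape=(n,),
                                  initializer="glorot_uniform", trainable=True)

    def call(self, inputs):
        # Feed-forward network: (batch, n) -> (batch, units**2)
        h = tf.nn.relu(tf.matmul(inputs, self.Tw1) + self.Tb)
        h = tf.matmul(h, self.Tw2)
        # Keep the batch axis: (batch, units**2) -> (batch, units, units)
        A = tf.reshape(h, [-1, self.units, self.units])
        # Batched matrix-vector product y[b] = A[b] @ x[b] -> (batch, units)
        return tf.linalg.matvec(A, inputs)


# Same model as in the question; data_x / data_y are random placeholders.
input_layer = tf.keras.Input(shape=[2])
output_layer = custom_layer(2)(input_layer)
model = tf.keras.Model(inputs=[input_layer], outputs=[output_layer])
model.compile(loss="mse",
              optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
              metrics=["mse"])
data_x = np.random.rand(256, 2).astype("float32")
data_y = np.random.rand(256, 2).astype("float32")
model.fit(data_x, data_y, epochs=10, verbose=0)
print(model.predict(data_x[:5], verbose=0).shape)  # (5, 2)
```

tf.einsum("bij,bj->bi", A, inputs) would compute the same product; matvec is used here only because it states the intent directly.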