python - 如何处理发生在 Keras 的自定义层中的代码错误?
问题描述
我想在 Keras 中制作一个自定义图层。在这个例子中,我使用一个变量来乘以张量,但我得到的错误是
在 /keras/engine/training_arrays.py 中,第 304 行,在 predict_loop outs[i][batch_start:batch_end] = batch_out ValueError: 无法将输入数组从形状 (36) 广播到形状 (2)。
实际上我已经检查了这个文件,但我什么也没得到。我的自定义层有问题吗?
#the definition of mylayer.
from keras import backend as K
import keras
from keras.engine.topology import Layer
class mylayer(Layer):
def __init__(self, output_dim, **kwargs):
self.output_dim = output_dim
super(mylayer, self).__init__(**kwargs)
def build(self, input_shape):
self.kernel = self.add_weight(name = 'kernel',
shape=(1,),dtype='float32',trainable=True,initializer='uniform')
super(mylayer, self).build(input_shape)
def call(self, inputs, **kwargs):
return self.kernel * inputs[0]
def compute_output_shape(self, input_shape):
return (input_shape[0], input_shape[1])
#the test of mylayer.
from mylayer import mylayer
from tensorflow import keras as K
import numpy as np
from keras.layers import Input, Dense, Flatten
from keras.models import Model
x_train = np.random.random((2, 3, 4, 3))
y_train = np.random.random((2, 36))
print(x_train)
x = Input(shape=(3, 4, 3))
y = Flatten()(x)
output = mylayer((36, ))(y)
model = Model(inputs=x, outputs=output)
model.summary()
model.compile(optimizer='Adam',loss='categorical_crossentropy',metrics=['accuracy'])
model.fit(x_train, y_train, epochs=2)
hist = model.predict(x_train,batch_size=2)
print(hist)
print(model.get_layer(index=1).get_weights())
#So is there some wrong in my custom error?
特别是,当我训练这个网络时,没关系,但是当我尝试使用“预测”时,它是错误的。
解决方案
你的形状self.kernel * inputs[0]
是(36,)
,但你的期望是(?,36)
。更改:
def call(self, inputs, **kwargs):
return self.kernel * inputs
如果要输出 的权重mylayer
,则应设置index=2
。
推荐阅读
- azure - 特定时间的 Log Analytics 警报规则
- javascript - 反应嵌套路由页面未按预期正确呈现
- unit-testing - 如何处理测试单元中的中止
- r - 从 R 中的光栅砖中选择的平均运行长度
- ruby-on-rails-5.2 - 升级到 ruby 2.5.0 和 Rails 5.2 后,to_h 作为范围失败
- python - 尝试从嵌入 Python 文档运行示例时出现“致命 Python 错误”
- html - 新网站 FeedMySheep
- r - 使用 R,我想遍历每一行并为每一行创建相应的卡方结果
- installation - 复制谷歌colab的python环境的最佳方法
- python-3.x - 测试尝试 - 捕获 Django 函数。需要异常测试