python - 向 cnn 的中间层添加一个常量值
问题描述
我想在学习过程中在cnn中的中间层的输出层添加一个常数矩阵,然后将其发送到下一层。我把我的代码放在这里并使用 Add 函数,但它会产生错误。我应该怎么办?使用 Add 是不是一个真正的解决方案?
from keras.layers import Input, Concatenate, GaussianNoise
from keras.layers import Conv2D
from keras.models import Model
from keras.datasets import mnist
from keras.callbacks import TensorBoard
from keras import backend as K
from keras import layers
import matplotlib.pyplot as plt
import tensorflow as tf
import keras as Kr
import numpy as np
w_main = np.random.randint(2,size=(1,4,4,1))
w_main=w_main.astype(np.float32)
w_expand=np.zeros((1,28,28,1),dtype='float32')
w_expand[:,0:4,0:4]=w_main
w_expand.reshape(1,28,28,1)
#-----------------------encoder------------------------------------------------
#------------------------------------------------------------------------------
image = Input((28, 28, 1))
conv1 = Conv2D(8, (5, 5), activation='relu', padding='same')(image)
conv2 = Conv2D(4, (3, 3), activation='relu', padding='same')(conv1)
conv3 = Conv2D(2, (3, 3), activation='relu', padding='same')(conv2)
encoded = Conv2D(1, (3, 3), activation='relu', padding='same')(conv3)
encoder=Model(inputs=image, outputs=encoded)
encoder.summary()
#-----------------------adding w---------------------------------------
encoded_merged=Kr.layers.Add(encoded,w_expand)
#-----------------------decoder------------------------------------------------
#------------------------------------------------------------------------------
#encoded_merged = Input((28, 28, 2))
x = Conv2D(2, (5, 5), activation='relu', padding='same')(encoded_merged)
x = Conv2D(4, (3, 3), activation='relu', padding='same')(x)
x = Conv2D(8, (3, 3), activation='relu',padding='same')(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same', name='decoder_output')(x)
decoder=Model(inputs=encoded_merged, outputs=decoded)
decoder.summary()
产生的错误是:
TypeError: init () 接受 1 个位置参数,但给出了 3 个我很着急。请帮我解决一下这个。
解决方案
您以错误的方式使用图层,这是正确的方式:
encoded_merged=Kr.layers.Add()([encoded,w_expand])
推荐阅读
- sql-server - 仅当该列没有不同的值时才更新表列
- java - 不公平锁如何比公平锁有更好的性能?
- vue.js - 加载和渲染状态组件 Vue
- javascript - 如何从 javascript 数组中选择具有特定类的特定 html 段落?
- node.js - NextJS SSR SSG 用户登录后或更新购物车时
- c# - C# 区域平滑问题。路径是平滑的,直到区域设置为路径
- php - 创建动态内容时,引导程序中的表崩溃
- python - matplotlib fill_between() 在第一个/最后一个点绘制不需要的额外线条
- amazon-web-services - 强制 AMAZON S3 Bucket 下载 https
- awk - 使用 sed 重新格式化行,将一个部分复制到单独行上其他部分的开头