tensorflow - 如何在 Keras/tensorflow 中将常量张量添加到输入张量
问题描述
我有一个简单的 CNN,其输入形状为 (5,5,3)。作为第一步,我想在输入中添加一个常量张量。使用下面的代码,我得到 AttributeError: 'NoneType' object has no attribute '_inbound_nodes'
我尝试了一些类似的东西
const_change = Input(tensor=tf.constant([ ...
或者
const_change = Input(tensor=K.variable([ ...
但似乎没有任何效果。非常感谢任何帮助。
from __future__ import print_function
import tensorflow as tf
import numpy as np
import keras
from keras import backend as K
from keras.models import Model
from keras.layers import Input
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
# Python 2.7.10
# keras version 2.2.0
# tf.VERSION '1.8.0'
raw_input = Input(shape=(5, 5, 3))
const_change = tf.constant([
[[5.0,0.0,0.0],[0.0,0.0,-3.0],[-10.0,0.0,0.0],[0.0,0.0,4.0],[-20.0,0.0,0.0]],
[[-15.0,0.0,12.0],[0.0,4.0,0.0],[-3.0,0.0,10.0],[-18.0,0.0,0.0],[20.0,0.0,-6.0]],
[[0.0,0.0,6.0],[0.0,-2.0,-6.0],[0.0,0.0,2.0],[0.0,0.0,-9.0],[7.0,-6.0,0.0]],
[[-3.0,4.0,0.0],[11.0,-12.0,0.0],[0.0,0.0,0.0],[0.0,0.0,7.0],[0.0,0.0,2.0]],
[[0.0,0.0,0.0],[0.0,1.0,-2.0],[4.0,0.0,3.0],[0.0,0.0,0.0],[0.0,0.0,0.0]]])
cnn_layer1 = Conv2D(32, (4, 4), activation='relu')
cnn_layer2 = MaxPooling2D(pool_size=(2, 2))
cnn_layer3 = Dense(128, activation='relu')
cnn_layer4 = Dropout(0.1)
cnn_output = Dense(4, activation='softmax')
proc_input = keras.layers.Add()([raw_input, const_change])
# proc_input = keras.layers.add([raw_input, const_change]) -> leads to the same error (see below)
lay1 = cnn_layer1(proc_input)
lay2 = cnn_layer2(lay1)
lay3 = Flatten()(lay2)
lay4 = cnn_layer3(lay3)
lay5 = cnn_layer4(lay4)
lay_out = cnn_output(lay5)
model = Model(inputs=raw_input, outputs=lay_out)
# -> AttributeError: 'NoneType' object has no attribute '_inbound_nodes'
解决方案
const_change
应该也是Input
一样的raw_input
。您可以创建另一个名为 的输入层const_input
,并将raw_input
和const_input
一起输入到模型中。
...
const_input = Input(tensor=const_change)
...
proc_input = keras.layers.Add()[raw_input, const_input]
...
model = Model(inputs=[raw_input, const_input], outputs=lay_out)
推荐阅读
- tsql - 在特定时间停止重复作业
- javascript - 通过 UI 过滤 JSON 对象数组
- angular - 如何处理角度材料中的窗口滚动?
- sql - 如何基于多列权重排序获得结果
- python - 一个数组中所有唯一值的两个数组中相同值的计数
- python - 如何使用 keras 的 fit 生成器的多处理模式解决多个进程写入同一个文件?
- macos - 在 MAC 上的 Visual Studio 2019 上 GTK# 工具箱为空
- selenium - 带有多个否定条件的运行关键字 If 的语法
- android - 在某些设备上无法在 Android 中设置 textColor
- distributed - 如何维护写操作的顺序?