python - 将自定义层和 py_function 与 opencv 方法一起使用
问题描述
我的开发环境是
- 视窗 10
- Python 3.6.8
- 张量流 1.13.1
我的目标是实现一个层,可以将每个 cnn 过滤器转换为 hu 矩不变量(每个过滤器 -> 7 维值)
所以,我想使用 Opencv 的 Humoment 方法
这是我定义的图层:
class MomentLayer(tf.keras.layers.Layer):
def __init__(self):
super(MomentLayer, self).__init__()
def build(self, input_shape):
self.oshape = input_shape
super(MomentLayer, self).build(input_shape)
def call(self, inputs, **kwargs):
xout = tf.py_function(image_tensor_func, (inputs,), 'float32', name='Cvopt')
xout.set_shape(tf.TensorShape((None, self.oshape[-1] * 7)))
return xout
def compute_output_shape(self, input_shape):
return tf.TensorShape((None, input_shape[-1] * 7))
我的 py_function 是
def image_tensor_func(img4d):
img4dx = tf.transpose(img4d, [0, 3, 1, 2])
all_data = np.array([])
for img3dx in img4dx:
tmp = np.array([])
for efilter in img3dx:
hu = cv2.HuMoments(cv2.moments(efilter.numpy())).flatten()
if tmp.shape[0] == 0:
tmp = hu
else:
tmp = np.concatenate((tmp, hu), axis=None)
if all_data.shape[0] == 0:
all_data = tmp
else:
all_data = np.vstack((all_data, tmp))
x = tf.convert_to_tensor(all_data, dtype=tf.float32)
return x
最后,我定义网络
input = tf.keras.layers.Input(shape=(10, 10, 1))
conv1 = tf.keras.layers.Conv2D(filters=3, kernel_size=5, activation=tf.nn.relu)(input)
test_layer = MomentLayer()(conv1)
dense1 = tf.keras.layers.Dense(units=12, activation=tf.nn.relu)(test_layer)
output = tf.keras.layers.Dense(units=10, activation=tf.nn.sigmoid)(dense1)
model = tf.keras.models.Model(inputs=input, outputs=output)
model.compile(optimizer=tf.train.RMSPropOptimizer(0.01),
loss=tf.keras.losses.categorical_crossentropy,
metrics=[tf.keras.metrics.categorical_accuracy])
print(model.summary())
和 model.summary() 工作正常!
但是当我尝试提供数据时
我有错误
tensorflow.python.framework.errors_impl.InvalidArgumentError: transpose 需要一个大小为 0 的向量。但 input(1) 是一个大小为 4 的向量 [[{{node training/TFOptimizer/gradients/Relu_grad/ReluGrad-0-TransposeNHWCToNCHW-LayoutOptimizer} }]] [操作:StatefulPartitionedCall]
我很确定数据的形状是正确的。
我想知道tensorflow不能写出这样的代码。
解决方案
您输入的数据由特征和标签组成。因此,您需要确保标签的形状也是正确的。
推荐阅读
- ruby-on-rails - 使用 Chartkick 下载图表
- ios - 取消对 Alamofire 图像的请求
- php - 如何在php中将xml文件显示为代码?
- ruby-on-rails - Rails 钩子 - `deliver.action_mailer` 未触发
- r - 如何在 r 中将数据框中的列显示为条形图?
- python - Python:如何从电子邮件中的链接保存网页(作为 html 文件)
- azure - AzureAD 中使用 REST API 进行纵向扩展和缩减的权限
- php - 使用 Silex、XAMPP 和多文件夹结构重写 htaccess/url
- entity-framework - Visual Studio 无法加载实体框架 PowerShell 脚本,因为它的操作被软件限制策略阻止
- vue.js - 在 Vuetify 数据表中构建超链接 - 也许是 v-bind?