python - Conv2D CNN 边缘检测脚本返回空白图像
问题描述
我想给图中的油、水和塑料上色。
目前我分割了训练图像(只使用正确着色的部分)。然后我训练一个 conv2D 网络并绘制预测。
当我运行它时,我会得到空白的蓝色或黑色图像作为回报。
请告知:
- - 我的网络是否合适,
- - 我应该使用什么参数
- - 我应该使用什么训练图像。
#IMPORT AND SPLIT
from cam_img_split import cam_img_split
import cv2
import numpy as np
img_tr_in=cv2.imread('frame 1.png')
img_tr_out=cv2.imread('Red edge.png')
seg_shape=[32,32]
tr_in=cam_img_split(img_tr_in,seg_shape)
tr_out=cam_img_split(img_tr_out,seg_shape)
pl=[4,20]
##################### NEURAL NETWORK
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense,Dropout,Conv2D, MaxPooling2D
from keras.optimizers import adam
import matplotlib.pyplot as plt
pad=3
input_shape=(seg_shape[0]+2*pad,seg_shape[1]+2*pad,3)
model = Sequential()
model.add(Conv2D(32, (3, 3),input_shape=input_shape, activation='relu'))
model.add(Conv2D(64,(3, 3), activation='relu'))
model.add(Conv2D(3, (3, 3),input_shape=input_shape, activation='relu'))
model.compile(optimizer=adam(lr=0.001), loss='mean_squared_error', metrics=['accuracy'])
tr_in_sel=tr_in[0:pl[0],0:pl[1],:,:,:]
tr_out_sel=tr_out[0:pl[0],0:pl[1],:,:,:]
tr_in_sel_flat=tr_in_sel.reshape([pl[0]*pl[1],seg_shape[0],seg_shape[1],3])
tr_out_sel_flat=tr_out_sel.reshape([pl[0]*pl[1],seg_shape[0],seg_shape[1],3])
tr_in_sel_flat_norm=tr_in_sel_flat/255
tr_out_sel_flat_norm=tr_out_sel_flat/255
from cam_pad import cam_pad
tr_in_sel_flat_norm_pad=np.zeros(tr_in_sel_flat.shape+np.array([0,2*pad,2*pad,0]))
for n3 in range(0,tr_in_sel_flat.shape[0]):
for n4 in range(0,tr_in_sel_flat.shape[3]):
tr_in_sel_flat_norm_pad[n3,:,:,n4]=cam_pad(tr_in_sel_flat_norm[n3,:,:,n4], pad)
model.fit(tr_in_sel_flat_norm_pad, tr_out_sel_flat_norm, epochs=10, batch_size=int(pl[0]/2),shuffle=True)
n_ch=10
img_check=np.zeros([n_ch,seg_shape[0]+2*pad,seg_shape[1]+2*pad,3])
for n8 in range(0,n_ch):
for n5 in range(0,3):
img_check[n8,:,:,n5]=cam_pad(tr_in_sel_flat_norm[n8,:,:,n5],pad)
pred = model.predict(img_check/255)
pred_img=(pred.reshape([n_ch,seg_shape[0],seg_shape[1],3]))
for n9 in range(1,n_ch):
plt.subplot(n_ch,1,n9)
plt.imshow(pred_img[n9-1,:,:,:])
plt.show()
解决方案
我不小心对我的数据进行了两次标准化(除以 255)。这导致网络对空白图像进行训练,从而产生空白图像。
推荐阅读
- flutter - 在警报对话框中使用提供程序不起作用
- api - 驱动 API 更新:
- javascript - React useEffect Api call with fetch 给出 SyntaxError: JSON.parse: unexpected end of data at line 1 column 1 JSON data
- javascript - 带有承诺的 JavaScript 中的事件循环
- javascript - 从 JSON 在反应应用程序中加载自定义主题数据
- python - 我的 getTerm 没有返回用户值,我的方法不完整 python
- javascript - React 表单 onChange 是后面的一个字符
- vaadin - 如何在 Vaadin Flow 的 FormLayout 中添加空格?
- r - 如何使用带有传单地图的嵌套模块
- python - Django 如何发送信号以从组中添加或删除用户?