python - TensorFlow 单层网络 Conv2d;无效参数:转置需要大小为 3 的向量。但 input(1) 是大小为 4 的向量
问题描述
我无法弄清楚尺寸不匹配的位置。似乎所有图层都是正确的形状。Conv2d 需要一个大小为 4 的向量。
#Here is the network structure
# inside parse_Depth_input I print the image shape (not sure why it's backwards but I did try reversing just in case)
d_image (228, 304, 1)
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 304, 228, 3) 30
_________________________________________________________________
reshape (Reshape) (None, 69312, 3) 0
=================================================================
Total params: 30
Trainable params: 30
Non-trainable params: 0
(0) 无效参数:transpose 需要一个大小为 3 的向量。但 input(1) 是一个大小为 4 的向量 [[{{node gradient_tape/simple_cnn/sequential/conv2d/Conv2D/Conv2DBackpropFilter-0-TransposeNHWCToNCHW-LayoutOptimizer}}] ] [[cond/then/_25/cond/cond/then/_124/cond/cond/remove_squeezable_dimensions/Equal/_110]] (1)无效参数:转置需要一个大小为3的向量。但输入(1)是一个向量大小为 4 [[{{node gradient_tape/simple_cnn/sequential/conv2d/Conv2D/Conv2DBackpropFilter-0-TransposeNHWCToNCHW-LayoutOptimizer}}]]
import tensorflow as tf
import numpy as np
from sklearn.model_selection import train_test_split
from PIL import Image
from tensorflow.keras import datasets, layers, models, preprocessing
import os
from natsort import natsorted
from tensorflow.keras.models import Model
BATCH_SIZE = 32
EPOCHS = 15
LEARNING_RATE = 1e-4
#jpegs with values from 0 to 255
img_dir = ".../normalized_imgs"
# .npy files of size (69312,3)
pts_dir = ".../normalized_pts"
img_files = [os.path.join(img_dir, f)
for f in natsorted(os.listdir(img_dir))]
pts_files = [os.path.join(pts_dir, f)
for f in natsorted(os.listdir(pts_dir))]
img = Image.open(img_files[0])
pts = np.load(pts_files[0])
def parse_img_input(img_file, pts_file):
def _parse_input(img_file, pts_file):
# get image
d_filepath = img_file.numpy().decode()
d_image_decoded = tf.image.decode_jpeg(tf.io.read_file(d_filepath), channels=1)
d_image = tf.cast(d_image_decoded, tf.float32) / 255.0
# get numpy data
pts_filepath = pts_file.numpy().decode()
pts = np.load(pts_filepath, allow_pickle= True)
print("d_image ",d_image.shape )
return d_image, pts
return tf.py_function(_parse_input,
inp=[img_file, pts_file],
Tout=[tf.float32, tf.float32])
class SimpleCNN(Model):
def __init__(self):
super(SimpleCNN, self).__init__()
input_shape = (img.size[0], img.size[1], 1)
self.model = model = models.Sequential()
model.add(tf.keras.Input(shape= input_shape))
model.add(layers.Conv2D(3, (3,3), padding='same'))
model.add(layers.Reshape((pts.shape[0], pts.shape[1])))
# split input data into train, test sets
X_train_file, X_test_file, y_train_file, y_test_file = train_test_split(img_files, pts_files,
test_size=0.2,
random_state=0)
model = SimpleCNN()
dataset_train = tf.data.Dataset.from_tensor_slices((X_train_file, y_train_file))
dataset_train = dataset_train.map(parse_img_input)
dataset_test = tf.data.Dataset.from_tensor_slices((X_test_file, y_test_file))
dataset_test = dataset_test.map(parse_img_input)
model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE), loss= tf.losses.MeanSquaredError(), metrics= [tf.keras.metrics.get('accuracy')])
model.fit(dataset_train, epochs=EPOCHS, shuffle=True, validation_data= dataset_test)
奇怪的是,如果我尝试将形状重塑为与输入相同的形状,尽管它正确显示了输入大小,但仍会出现错误
input_shape = (img.size[0], img.size[1], 1)
model.add(tf.keras.Input(shape= input_shape))
model.add(layers.Reshape((input_shape[0], input_shape[1], 1)))
reshape 的输入是一个有 69312 个值的张量,但请求的形状有 15803136
解决方案
我的问题是我没有事先对数据集进行批处理。我添加了这个
def prepare_dataset(ds, shuffle = False):
if shuffle:
ds = ds.shuffle(buffer_size=500)
ds = ds.batch(BATCH_SIZE)
ds = ds.prefetch(buffer_size=AUTOTUNE)
return ds
dataset_train = prepare_dataset(dataset_train, shuffle=True)
推荐阅读
- ios - 如何快速实施实时条纹苹果支付?Stripe 测试密钥工作正常但无法使用,我该如何解决?
- node.js - React js - 反应中的网络套接字链接代理问题
- javascript - 如何使用JavaScript从数组中删除字符后的数字和字母
- c++ - 使用循环展开加速 do-while 循环
- python - 根据规则操作列中的值
- node.js - 无法在 create-react-app 中运行 npm start
- flutter - SqfliteDatabaseException (DatabaseException(unrecognized token: "498a" (code 1): , while compile: DELETE FROM Products WHERE
- asp.net-core - Identity.IsAuthenticated 在 HTTPGET 控制器中的 SignInAsync() 之后返回 false
- powershell - 如何在 Powershell 别名中运行连续命令?
- python - 如何使用 sympy 求解方程?