python - ValueError:检查目标时出错:预期dense_2有2维,但得到的数组形状为(1,)
问题描述
我正在准备一些机器学习来识别给定图像是否包含特征(1)或不包含特征(0)。但是标签形状和模型的输出似乎不同。
所有信息都包含在一个张量流数据集中:
path_ds = tf.data.Dataset.from_tensor_slices(allImages)
img_ds = path_ds.map(preprocess_image)
label_ds = tf.data.Dataset.from_tensor_slices(labels)
label_ds = label_ds.map(conversion)
ds = tf.data.Dataset.zip((img_ds, label_ds))
ds = ds.shuffle(buffer_size=image_count).repeat().batch(5).prefetch(10)
iterator = ds.make_one_shot_iterator()
ds_x, ds_y = iterator.get_next()
模型是这样的:
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(704, 480, 3)))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3,3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3,3), activation='relu'))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='softmax'))
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(ds_x, ds_y, epochs=3, steps_per_epoch=1, verbose=3)
allImages 和标签是列表。所有图片都包含图片的路径,后面会进行预处理,labels是1或0的列表
我收到的错误消息如下:
ValueError:检查目标时出错:预期dense_2有2维,但得到的数组形状为(1,)
完整代码:
import tensorflow as tf
import os
import glob
import csv
import lab
from PIL import ImageFile
import datetime
import pandas as pd
def preprocess_image(path):
img = tf.read_file(path)
img = tf.image.decode_jpeg(img, channels=3, try_recover_truncated=True, acceptable_fraction=0.9)
img = tf.cast(img, tf.float32)
return img
def conversion(label):
label = tf.cast(label, tf.float32)
return label
def get_keys(d):
if not d.keys():
return None
if list(d.keys())[0] == 'yes':
return 1
if list(d.keys())[0] == 'no':
return 0
def change_range(image,label):
return 2*image-1, label
print("Start", datetime.datetime.now())
orig_path = 'PATH'
CAM = 'CAM'
allImages = []
labels = []
print("Parse labels", datetime.datetime.now())
for folders in os.listdir(orig_path):
print("Reading one folder")
df = pd.read_csv(orig_path + '/' + folders + '/' + folders + '.csv')
df.drop(['file_size', 'region_count', 'region_id', 'region_shape_attributes', 'region_attributes'], axis=1, inplace=True)
df.file_attributes = df.file_attributes.str.replace('true', 'True')
labelsDF = pd.DataFrame(df.file_attributes.apply(eval).values.tolist())
mask = labelsDF.time == 'day'
labels2 = pd.DataFrame(labelsDF[mask].drop(['cloud', 'feature', 'light', 'problems', 'time'], axis=1).eruption.apply(get_keys))
df.file_attributes = labels2
df = df[mask]
df.dropna(inplace=True)
mask2 = df['filename'].isin(os.listdir(orig_path + '/' + folders))
df = df[mask2]
df.filename = '/home/mitiga/Images/' + folders +'/' + df.filename
labels = labels + df['file_attributes'].tolist()
allImages = allImages + df['filename'].tolist()
image_count = len(allImages)
path_ds = tf.data.Dataset.from_tensor_slices(allImages)
img_ds = path_ds.map(preprocess_image)
label_ds = tf.data.Dataset.from_tensor_slices(labels)
label_ds = label_ds.map(conversion)
print(label_ds.output_shapes)
ds = tf.data.Dataset.zip((img_ds, label_ds))
ds = ds.shuffle(buffer_size=image_count).repeat().batch(5).prefetch(10)
iterator = ds.make_one_shot_iterator()
ds_x, ds_y = iterator.get_next()
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(704, 480, 3)))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3,3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3,3), activation='relu'))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='softmax'))
print(model.summary())
print(ds_y.shape)
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(ds_x, ds_y, epochs=3, steps_per_epoch=1, verbose=3)
解决方案
推荐阅读
- javascript - 文本框中的数组列表用于表中的日期类型部分视图 mvc
- java - 关于ArrayList的机制问题
- oracle - 有没有办法制作一个 PLSQL 脚本来列出表中每条记录为 NULL 的所有列?
- visual-studio-code - 自定义调试器停止/重启按钮行为
- php - “如何修复“找不到类'SplString'”
- python - annotate 的 django 结果不包含在序列化程序中
- python - 使用 Cuda 时:TypeError:('关键字参数不理解:','激活')
- forms - 上传的图片太大:出现循环引用
- ubuntu - Javafx 任务栏图标 - 问题
- java - 如何在 Java 中制作循环列表