python - 检查输入时出错:预期 flatten_input 有 4 个维度,但得到了形状为 (404, 13) 的数组
问题描述
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
data = keras.datasets.boston_housing
(x_train , y_train) , (x_test , y_test) = data.load_data()
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28 , 28 )),
keras.layers.Dense(128 , activation="relu"),
keras.layers.Dense(10 , activation="softmax")
])
model.compile(optimizer="adam" , loss="sparse_categorical_crossentropy" , metrics=["accuracy"])
model.fit(x_train , y_train ,epochs=5 )
test_loss , test_acc = model.evaluate(x_test , y_test)
print("tested acc: ", test_acc)
解决方案
Flatten
用于使图像扁平化的二维数据,所以基本上你将二维列表转换为一维列表,所以你应该更改Flatten
为Input
.
第二个错误是声明输入形状。
input_shape=(28 , 28 )
您声明28x28
了 ,但我认为您希望拥有 28 个具有 28 个特征的样本。这是不变的。为了正确地做到这一点,您将输入形状定义为灵活,它将匹配训练和预测中的任意数量的样本。所有你需要做的就是传递一个样本有多少特征
input_shape=(28, )
这就是它的样子
model = keras.Sequential([
keras.layers.Input(input_shape=(28, )),
keras.layers.Dense(128 , activation="relu"),
keras.layers.Dense(10 , activation="softmax")
])
推荐阅读
- pointers - 为什么 Rust 认为泄漏内存是安全的?
- python - 如何在 keras 中操作可训练的张量乘法运算?
- rgraph - 如何解决 RGraph 饼图中的标签冲突问题?
- javascript - React:如何关闭从父组件打开的子模式
- python - 是否可以在按下按钮时从谷歌表格运行 python 脚本?
- c# - Fluent Assertions Should().BeEquivalentTo 只有私有字段
- spring-boot - 如何修复 BeanInstantiationException?
- html - Flex 内容不适合父内容的高度
- cocoa - NSScrollView 截断文本并且没有滚动条
- r - 通过 Rcurl 上传文件