python - 谁能解释我在 tensorflow keras 中的问题?
问题描述
import tensorflow as tf
from tensorflow import keras
import numpy as np
training_inputs = np.array([[0,0,1],
[1,1,1],
[1,0,1],
[0,1,1]])
training_outputs = np.array([[0,1,1,0]])
model = keras.Sequential([
keras.layers.Flatten(input_shape=(1,3)),
keras.layers.Dense(1,activation="sigmoid")
])
model.compile(optimizer = "rmsprop",
loss = "binary_crossentropy",
metrics = ["accuracy"])
model.fit(training_inputs,training_outputs,epochs=1)
prediction = model.predict(np.array([[1,1,0]]))
打印(预测)
它有这些问题
回溯(最近一次通话最后):文件“C:/Users/Αλέξης/Desktop/Youtube/test.py”,第 21 行,model.fit(training_inputs,training_outputs,epochs=1) 文件“C:\Users\Αλέξης \AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\engine\training.py”,第 709 行,适合 shuffle=shuffle)文件“C:\Users\Αλέξης\AppData\ Local\Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\engine\training.py”,第 2651 行,在 _standardize_user_data exception_prefix='input') 文件“C:\Users\Αλέξης\AppData\Local \Programs\Python\Python37\lib\site-packages\tensorflow\python\keras\engine\training_utils.py",第 376 行,在 standardize_input_data 'with shape ' + str(data_shape)) ValueError:检查输入时出错:预期 flatten_input 有 3 个维度,但得到了形状为 (4, 3) 的数组
任何人都可以帮忙吗?
解决方案
我删除了训练输出数组的方括号,以便训练输入中的每个条目都有一个关联的类,而不仅仅是第一个有 4 个
我还删除了 flatten 层并将 input_shape 参数设置为dense layer
import tensorflow as tf
from tensorflow import keras
import numpy as np
training_inputs = np.array([[0,0,1],
[1,1,1],
[1,0,1],
[0,1,1]])
training_outputs = np.array([0,1,1,0])
model = keras.Sequential([
keras.layers.Dense(1,activation="sigmoid",input_shape=(3,))
])
model.compile(optimizer = "rmsprop",
loss = "binary_crossentropy",
metrics = ["accuracy"])
model.fit(training_inputs,training_outputs,epochs=1)
推荐阅读
- sql - SQL Server - 如果所有列之一是唯一的,请选择
- sql-server - SQL 数据库 防止记录重复要走多远?
- json - D/Volley:[1] 5.onErrorResponse:产品
- php - 如何通过位置 ID 或纬度和经度获取 instagram 媒体?
- python - 如何使用 BeautifulSoup 从 Oddshark.com 解析这个 javascript?
- css - 如何将继承值存储在 CSS 变量(又名自定义属性)中?
- python - 打印非 Ascii 列值,python-spark
- angular - 将数据传递给sidemenu Ionic
- javascript - 将任务添加到任务列表后禁用按钮
- python - 为什么我不断收到此错误“RuntimeWarning:int_scalars 中遇到溢出”