python - 如何将多个数组提供给 keras 机器学习算法?
问题描述
我打算制作一个保费预测器,它根据包括性别、性别、BMI 等在内的多个因素来预测您的保险费。(总共 6 个)我有数据,但我不知道如何喂多个数组给它。这是代码-
import tensorflow as tf
import numpy as np
from tensorflow import keras
import pandas as pd
a=0
file=pd.read_csv(r"""C:\Users\lavni\OneDrive\Desktop\proj.csv""",sep=',')
Age=pd.DataFrame(file,columns=['age']).to_numpy()
Sex=pd.DataFrame(file,columns=['sex']).to_numpy()
BMI=pd.DataFrame(file,columns=['bmi']).to_numpy()
Children=pd.DataFrame(file,columns=['children']).to_numpy()
Smoker=pd.DataFrame(file,columns=['smoker']).to_numpy()
Region=pd.DataFrame(file,columns=['region']).to_numpy()
Charges=pd.DataFrame(file,columns=['charges']).to_numpy()
Data=[Age,Sex,BMI,Children,Smoker,Region]
model=keras.Sequential([keras.layers.Dense(units=6,input_shape=[6])])
model.compile(optimizer='sgd',loss='mean_squared_error')
model.fit(Data,Charges)
运行它时,它给了我以下错误:
ValueError: Layer sequential expects 1 inputs, but it received 6 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=(None, 1) dtype=int64>, <tf.Tensor 'IteratorGetNext:1' shape=(None, 1) dtype=string>, <tf.Tensor 'IteratorGetNext:2' shape=(None, 1) dtype=float32>, <tf.Tensor 'IteratorGetNext:3' shape=(None, 1) dtype=int64>, <tf.Tensor 'IteratorGetNext:4' shape=(None, 1) dtype=string>, <tf.Tensor 'IteratorGetNext:5' shape=(None, 1) dtype=string>]
我理解这个错误,对它进行了一些研究,我知道它需要是一个元组,但是尽管改变了它,它仍然给出了同样的错误。
提前致谢。
解决方案
这里有几个问题,首先,不要pd.DataFrame
用于从文件中加载单列数据(一维数据)。如果您只需要阅读单列,请使用pd.read_csv
(可选)和参数。squeeze
接下来,您不能使用 python 列表对多个输入进行分组,您必须使用元组,如下所示:
Data = (a,b,c,d) # Not [a,b,c,d]
否则,keras 模型将尝试提供您的数据集,a
然后b
以此类推,而不是[a[0], b[0], ...]
. 这似乎是您的目标。
接下来,在您的特定情况下,您似乎不需要多输入 keras 模型。你所需要的只是input_shape
你的第一层的有效。这就是模特抱怨的原因。它需要一个输入,而您正试图插入其中的 6 个。将输入转换为单个数组,例如
data = dataframe.to_numpy() #dataframe must have 6 columns!
targets = targets.to_numpy() #single column
...
model.fit(data,targets)
使用此信息更新您的代码,并在出现其他问题时询问我。
干杯!
推荐阅读
- reactjs - React 中的 componentDidMount 函数
- javascript - Webdriver.io - 很可能无法加载规范文件,因为它们依赖于“浏览器”对象
- node.js - 在客户端从 Nodejs 调用函数:未定义要求
- javascript - 在 PHP 中提交后保留下拉值
- python - 编写一个接收字符串并计算元音的while循环?
- sql - 检查密码上次设置日期并在 SQL Server 中发送通知
- azure-devops - 当 Azure Boards 中有未关闭的子工作项时,如何禁止关闭父工作项?
- fullcalendar - FullCalender 可以配置 TimeSlots: Morning, Afternoon & Evening
- laravel - Laravel 7 清理表单数据
- node.js - 在 ExpressJS 中的任何请求之前调用的中间件