python - 如何在 python 中使用 keras 训练具有列表数组的神经网络
问题描述
我正在尝试使用 tensorflow.keras 训练神经网络,但我不明白如何使用 numpy 列表数组(在 python3 中)训练它。
我试图改变图层的输入形状,但我真的不明白它是如何工作的。
import tensorflow as tf
from tensorflow import keras
import numpy as np
# Create the array of data
train_data = [[1.0,2.0,3.0],[4.0,5.0,6.0]]
train_data_np = np.asarray(train_data)
train_label = [[1,2,3],[4,5,6]]
train_label_np = np.asarray(train_data)
### Build the model
model = keras.Sequential([
keras.layers.Dense(3,input_shape =(3,2)),
keras.layers.Dense(3,activation=tf.nn.sigmoid)
])
model.compile(optimizer='sgd',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
#Train the model
model.fit(train_data_np,train_label_np,epochs=10)
当调用 model.fit 时,错误是“检查输入时出错:预期的 dense_input 具有 3 个维度,但得到的数组的形状为 (2, 3)”。
解决方案
在定义 Keras 模型时,您必须为模型的第一层提供输入形状。
例如,如果您的训练数据有n
行和m
特征,即形状:(n,m),您必须将模型input_shape
的第一Dense
层设置为,(m, )
即模型应该期望m
特征进入其中。
现在来到你的玩具数据,
train_data = [[1.0,2.0,3.0],[4.0,5.0,6.0]]
train_data_np = np.asarray(train_data)
train_label = [[1,2,3],[4,5,6]]
train_label_np = np.asarray(train_label)
在这里,train_data_np.shape
就是(2, 3)
行2
和3
特征,那么你必须像这样定义模型,
model = keras.Sequential([
keras.layers.Dense(3,input_shape =(3, )),
keras.layers.Dense(3,activation=tf.nn.sigmoid)
])
现在,您的标签是[[1,2,3],[4,5,6]]
. 在正常的3
类分类任务中,这将是一个带有1
和0
s 的单热向量。但是让我们把它放在一边,因为这是一个检查 Keras 的玩具示例。
如果目标标签 ie y_train
是 one-hot 那么你必须使用categorical_crossentropy
loss 而不是sparse_categorical_crossentropy
.
所以你可以像这样编译和训练模型
model.compile(optimizer='sgd',loss='categorical_crossentropy',metrics=['accuracy'])
#Train the model
model.fit(train_data_np,train_label_np,epochs=10)
Epoch 1/10
2/2 [==============================] - 0s 61ms/step - loss: 11.5406 - acc: 0.0000e+00
Epoch 2/10
2/2 [==============================] - 0s 0us/step - loss: 11.4970 - acc: 0.5000
Epoch 3/10
2/2 [==============================] - 0s 0us/step - loss: 11.4664 - acc: 0.5000
Epoch 4/10
2/2 [==============================] - 0s 498us/step - loss: 11.4430 - acc: 0.5000
Epoch 5/10
2/2 [==============================] - 0s 496us/step - loss: 11.4243 - acc: 1.0000
Epoch 6/10
2/2 [==============================] - 0s 483us/step - loss: 11.4087 - acc: 1.0000
Epoch 7/10
2/2 [==============================] - 0s 1ms/step - loss: 11.3954 - acc: 1.0000
Epoch 8/10
2/2 [==============================] - 0s 997us/step - loss: 11.3840 - acc: 1.0000
Epoch 9/10
2/2 [==============================] - 0s 1ms/step - loss: 11.3740 - acc: 1.0000
Epoch 10/10
2/2 [==============================] - 0s 995us/step - loss: 11.3653 - acc: 1.0000
推荐阅读
- ember.js - 如何从 Ember 中的控制器重新加载路由模型?
- javascript - 如果可能,我如何使用 .map 方法对数组中的项目进行排序
- c - 如何使用 C 读取视频文件
- javascript - 如何从 lua 中的多维数组访问值?
- jquery - 使用 jQuery Knob 我想在值增加时切换类
- asp.net - 如何在 asp.net webform 应用程序中正确添加 ajax 组合框
- spring-boot - 如何在不使用@Query Annotation的情况下从方法名称中仅检索Springboot中Crud Repository中单个列的值?
- laravel - Docker 上的 Laravel:[2002] 连接被拒绝
- database - 折扣 UML 图
- java - 从java中的时间戳按月分组火花数据集