python - 单层网络时序数据示例
问题描述
我正在阅读Thomas P. Trappenberg在机器学习基础中的示例。这本书没有提供承诺的实际 jupyter notebook 文件,所以我通过复制和阅读示例代码来学习。
这是第 9 章中关于分析序列数据的介绍性示例。我们想要一个具有单个隐藏层的简单网络来正确预测正弦波。
# sine sequence
import numpy as np
import matplotlib.pyplot as plt
from keras import models, layers, optimizers, datasets, utils, losses
# sine data with 10 steps/cycle
seq = np.array([np.sin(2*np.pi*i/10) for i in range(10)])
print(seq)
num_seq = 200
x_train = np.array([])
y_train = np.array([])
for i in range(num_seq):
ran = np.random.randint(10)
x_train = np.append(x_train, seq[ran])
y_train = np.append(y_train, seq[np.mod(ran+1, 10)])
x_test = np.array(seq)
y_test = np.array(np.roll(seq, -1))
到目前为止,我看到我们从 0 到 10 域的正弦波中挑选了 200 个点。x_train 包含来自正弦函数的 200 个值,y_train 包含我们希望预测的序列中的下一个值。
下面的代码应该使用序列中前两个点的知识来预测正弦函数。这是我在运行代码时发现错误的地方。
# MLP2
inputs = layers.Input(shape = (2, ))
h = layers.Dense(2, activation = 'relu')(inputs)
outputs = layers.Dense(1, activation = 'tanh')(h)
model = models.Model(inputs, outputs)
model.compile(loss = 'mean_squared_error', optimizer = 'adam')
print(model.summary())
model.fit(x_train, y_train, epochs = 1000, batch_size = 100, verbose = 0)
# evaluate
y_pred = model.predict(x_test, batch_size = 10, verbose = 1)
plt.plot(y_test, 'x')
plt.plot(y_pred, 'o')
当我运行此代码时,从行
model.fit(x_train, y_train, epochs = 1000, batch_size = 100, verbose = 0)
我明白了
ValueError:检查输入时出错:预期 input_11 的形状为 (2,) 但得到的数组的形状为 (1,)
我有点明白,因为 x_train 和 y_train 都具有形状 (1, ),因为它们被理解为向量。当我复制代码时
inputs = layers.Input(shape = (2, ))
我盲目地认为“也许这就是 Keras 是如何编写来理解序列的过去两个条目的”,因为这是我第一次学习机器学习,而且我对 Keras 并不熟悉。
您是否看到示例代码中的错误、结构错误或我犯的人为错误?
(我有一个直接的后续问题,因为下一个示例介绍了 RNN,它以一段代码开头
# RNN
x_train=np.reshape(x_train, (200, 2, 1) )
x_test=np.reshape(x_test, (10, 2, 1) )
这不起作用,因为当前 x_train 只有 200 个条目,它不能转换为形状 (200, 2, 1)。我想如果有人可以回答我最初的问题,也许这个后续问题会自动解决。)
感谢您的时间。
解决方案
您只需要编辑输入形状大小:
inputs = layers.Input(shape = (1, ))
然后尝试运行代码,希望对您有所帮助!
推荐阅读
- javascript - 在 Javascript 中更有效地搜索大型数组?
- c - 优先队列未按升序插入元素
- arrays - 如何在 d3.js 中引用来自 CSV 的数据
- r - 如何使用 Dygraphs 包绘制字符类每月平均值?
- python - tf.dynamic_partition 获取非空张量
- c# - 在 EF Core 中设置属性约定?
- c# - 如何删除字符串内的特定空白
- python - Python str() 未被列表理解调用
- java - @Transactional 在我的 Spring Boot 应用程序中无效
- javascript - 创建用于不同服务器的 javascript 文件