python - 学习正弦函数似乎在 ANN (Keras) 中采用了过多的参数
问题描述
我一直在尝试对不同的函数逼近方法进行一些研究,我尝试的第一个方法是使用 ANN(人工神经网络)。代码如下 -
import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tensorflow.keras.layers import Input, Dense, Flatten, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from sklearn.preprocessing import MinMaxScaler
X = np.linspace(0.0 , 2.0 * np.pi, 20000).reshape(-1, 1)
Y = np.sin(X)
x_scaler = MinMaxScaler()
y_scaler = MinMaxScaler()
X = x_scaler.fit_transform(X)
Y = y_scaler.fit_transform(Y)
plt.plot(X, Y)
plt.show()
inp = Input(shape=(20000, 1))
x = Dense(32, activation='relu')(inp)
x = Dense(64, activation='relu')(x)
x = Dense(128, activation='relu')(x)
x = Dense(256, activation='relu')(x)
predictions = Dense(1, activation='linear')(x)
model = Model(inp, predictions)
model.compile(loss='mse', optimizer='adam')
model.summary()
X = X.reshape((-1, 20000, 1))
Y = Y.reshape((-1, 20000, 1))
history = model.fit(X, Y, epochs=500, batch_size=32, verbose=2)
X_test = np.linspace(0.0 , 2.0 * np.pi, 20000).reshape(-1, 1)
X_test.shape
X_test = x_scaler.transform(X_test)
X_test = X_test.reshape((-1, 20000, 1))
res = model.predict(X_test, batch_size=32)
res = res.reshape((20000, 1))
res_rscl = y_scaler.inverse_transform(res)
Y_rscl = y_scaler.inverse_transform(Y.reshape(20000, 1))
plt.subplot(211)
plt.plot(res_rscl, label='ann')
plt.plot(Y_rscl, label='train')
plt.xlabel('#')
plt.ylabel('value [arb.]')
plt.legend()
plt.subplot(212)
plt.plot(Y_rscl - res_rscl, label='diff')
plt.legend()
plt.show()
情节如下 -
正如我们所看到的,它确实非常接近这种架构的正弦曲线。但是,我不确定我做的是否正确。对我来说,我需要43,777
参数来拟合正弦曲线看起来很奇怪。也许我错了。然而,看着这个 R 代码(我根本不知道 R,但我猜 ANN 比我拥有的要小得多)让我想知道更多。
我的问题 - 我的方法对吗?我应该改变一些东西以减少参数的数量吗?或者,正弦是一个困难的函数,对于人工神经网络来说,它需要大量的参数来近似它是正常的吗?
这可能是一个开放式的问题,但我真的很感激你能指出我的任何方向以及我犯的任何错误,你可以告诉我。
注意 -这个问题表明数据的循环性质对 ANN 来说是一件困难的事情。我还想知道这是否真的如此,以及这是否是 ANN 采用如此多参数的原因。
解决方案
推荐阅读
- python - 在 TensorFlow/Keras 中加载 TIFF 图像数据集
- python - 如何在排行榜中对分数进行排序?
- java - 二叉搜索树布尔返回类型
- python - 使用 Nans 对 griddata 进行 2d 插值会导致 NaN,即使在执行插值之前屏蔽掉原始数据中的 Nans 之后?
- php - 在 magento 2.4 中扩展 grahpql 的模式
- powershell - 使用 Powershell 自动启动/停止 Azure Function App
- javascript - 在我的本地存储 Jquery 代码上添加了新字段
- html - 在 Angular2 从 JSON 生成复杂的 html
- java - 收到 Hikari 池初始化错误
- python-3.x - 如何用装饰器包装 func.__code__.co_filename?