python - 在函数中使用 keras 神经网络
问题描述
我正在尝试使用 keras 从论文中实现一种算法,他们在其中训练神经网络以使用有限数量的数据点来逼近数学函数 f(x)。我希望神经网络的输入为 x,输出形式为 f(x) = 1 + xN(x),其中 N(x) 是来自最终密集层的值。
我知道如何使它对输出 f(x) = N(x) 起作用,但我只是不知道如何调整网络以获得 f(x) = 1 + xN(x)。有人能帮我吗?
这是我当前的代码
from keras.layers import Input, Dense, Add, Multiply
from keras.models import Model
import keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
import time
def f(x):
return x**2
Xtrain = np.linspace(0, 1, 10)
ytrain = np.array([f(x) for x in Xtrain])
X = np.linspace(0, 2, 100)
y = np.array([f(x) for x in X])
input = Input(shape=(1,))
init = np.ones(shape=(10, 1))
init = K.variable(init)
hidden = input
hidden = Dense(8, activation='relu')(hidden)
out = Dense(1, activation='linear')(hidden)
out = Add()([init, Multiply()([out, input])])
model = Model(inputs=input, outputs=out)
model.compile(loss='mean_squared_error', optimizer="adam")
tic = time.perf_counter()
model.fit(Xtrain, ytrain, epochs=1000, verbose=1)
toc = time.perf_counter()
print(f"Training time: {toc - tic:0.4f} seconds")
prediction = model.predict(X)
prediction = prediction.reshape((100,))
plt.figure(figsize=(10,5))
plt.plot(X, y, color='red', label='Analytical solution')
plt.plot(X, prediction, color='black', label = 'Prediction')
plt.scatter(Xtrain, ytrain, color='blue', label='Training points')
plt.legend()
plt.show()
plt.tight_layout()
在线崩溃
out = Add()([init, Multiply()([out, input])])
解决方案
添加层在两个层之间以及一个层和一个数字/ndarray 之间工作。
你可以像这样使用它:
init=np.ones(shape=(10, 1))
inp = Input(shape=(1,))
hidden = Dense(8, activation='relu')(inp)
out = Dense(1, activation='linear')(hidden)
mul=Multiply()([out, inp])
out = Add()([init, mul])
model = Model(inputs=inp, outputs=out)
model.compile(loss='mean_squared_error', optimizer="adam")
我检查了它并且它有效。
顺便说一句,input
是一个内置函数,除非你想使用它,否则我不建议使用它。
推荐阅读
- flask - 为什么照片加载不出来?
- java - 仓库找到 3 条记录,结果列表返回 6 条记录
- spring-boot - 使用 Spring Boot 保护 Web 应用程序不起作用
- java - 将值与 ArrayList Java 的特定对象进行比较时出现问题
- html - 是否可以不触发:悬停在 ::before 或 ::after 上?
- python - 根据每天在 df 中更改的列在 df 中执行计算
- azure - 来自 MS Teams 对话机器人的流量未通过 Azure 应用程序网关到达机器人
- javascript - 我希望以下正则表达式返回 16 个字符
- c++ - 对类型“const int *”的非 const 左值引用不能绑定到不相关类型“int *”的值
- json - 为 jq 命令提供一个非常大的参数来过滤键