python - 为什么简单回归的 TensorFlow 结果与其输入相去甚远?
问题描述
我有一个简单的回归训练数据,如下所示。我想在 TensorFlow 中训练网络,然后[1 0 1]
再次输入(与示例 3 相同)到网络中,这应该给我一个接近 1 的值(比如 0.99)。
现在这是我的 TensorFlow 代码(在 Python 3 中)。我使用了一个线性层,然后是一个 Sigmoid。我使用均方损失。请注意,在最后几行中,我输入[1 0 1]
了测试模型的预测能力。我只是得到0.5015
了,这与我的期望相去甚远(即0.99
)。
版本 1:TensorFlow 代码:
import tensorflow as tf
import numpy as np
batch_xs=np.array([[0,0,1],[1,1,1],[1,0,1],[0,1,1]])
batch_ys=np.array([[0],[1],[1],[0]])
x = tf.placeholder(tf.float32, [None, 3])
W = tf.Variable(tf.zeros([3, 1]))
b = tf.Variable(tf.zeros([1]))
y = tf.nn.sigmoid(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 1])
mean_square_loss = tf.reduce_mean(tf.square(y_ - y))
learning_rate = 0.05
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(mean_square_loss)
sess = tf.Session()
tf.global_variables_initializer().run(session=sess)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
x0=np.array([[1.,0.,1.]])
x0=np.float32(x0)
y0=tf.nn.sigmoid(tf.matmul(x0,W) + b)
print('%.15f' % sess.run(y0))
为什么结果与预期值相差甚远?如果我只是使用 Numpy 而不是 TensorFlow,下面 9 行代码可以实现0.9936
.
版本 2:Numpy 代码:
from numpy import exp, array, random, dot
training_set_inputs = array([[0, 0, 1], [1, 1, 1], [1, 0, 1], [0, 1, 1]])
training_set_outputs = array([[0, 1, 1, 0]]).T
random.seed(1)
synaptic_weights = 2 * random.random((3, 1)) - 1
for iteration in range(10000):
output = 1 / (1 + exp(-(dot(training_set_inputs, synaptic_weights))))
synaptic_weights += dot(training_set_inputs.T, (training_set_outputs - output) * output * (1 - output))
print(1 / (1 + exp(-(dot(array([1, 0, 1]), synaptic_weights)))))
如何修复版本 1 中的 TensorFlow 代码,使结果接近 0.99?非常感谢!
解决方案
在您的 tensorflow 代码中sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
更新/训练您的权重。请注意,您只运行一次,学习率为 0.05。但是在你的 numpy 代码中,你运行了 10000 次迭代,这相当于做
for i in range(10000):
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
结果应该是 0.95 左右。如果您将学习率提高到 1,就像在 numpy 代码上所做的那样,您应该得到预期的行为(0.99)。
推荐阅读
- r - 使用 RQDA 库的 R (3.6.1) 出现致命错误
- javascript - 将 .js 文件作为文本导入并在 react-native expo 中的 WebView 中使用
- java - 在类路径中找不到类:Open_MRS_POM.Open_MRS_Engine
- javascript - 禁用 CSS 类的视觉效果
- d - 使用 lambda 调用时,编译时折叠会导致错误
- windows - Kafka客户端绑定IP(二级网卡)
- maven - Intellij 未显示多模块 Maven 项目的子模块
- python - 是否可以通过函数传递带有前导零的整数?
- hive - 在 DataFrame 中不能有调用集合操作的地图类型列
- jquery - 如何在 Django 中实现 toast 消息?