python - “[...]”用tensorflow在磁盘上加载数据是什么意思
问题描述
我从 tensorflow 和 Machine Learning 开始,但我有一本书非常有用。我想实现小批量梯度下降。我完全按照他们所说的去做,但没有用。
它被解释为:“最后,在执行阶段,一个一个地获取小批量,然后在评估依赖于它们中的任何一个的节点时,通过 feed_dict 参数提供 X 和 Y 的值。”
我正在使用Jupyter notebook
,tensorflow 1.3.0.
这是我尝试过的:
n_epochs=1000
learning_rate=0.0001
#X=tf.constant(housing_data_plus_bias,dtype=tf.float32,name="X")
X=tf.placeholder(tf.float32,shape=(None,n+1),name="X")
Y=tf.placeholder(tf.float32,shape=(None,1),name="Y")
batch_size=100
n_batches=int(np.ceil(m/batch_size))
#Y=tf.constant(housing.target.reshape(-1,1),dtype=tf.float32,name="Y")
theta=tf.Variable(tf.random_uniform([n+1,1],-1.0,1.0),name="theta")
y_pred=tf.matmul(X,theta,name="predictions") #eq 1.4
error=y_pred - Y
mse=tf.reduce_mean(tf.square(error),name="mse") #eq 1.5
#gradients=tf.gradients(mse,[theta])[0]
gradients= (2/(m*mse) ) * tf.matmul(tf.transpose(X),error)
training_op = tf.assign(theta,theta - learning_rate * gradients)
def fetch_batch(epoch,batch_index,batch_size):
[...] #Load DATA FROM DISK (SEE NOTEBOOK)
return X_batch, Y_batch
init=tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
for epoch in range(n_epochs):
for batch_index in range(n_batches):
X_batch,Y_batch=fetch_batch(epoch,batch_index,batch_size)
sess.run(training_op,feed_dict={X:X_batch,Y:Y_batch})
best_theta=theta.eval()
print(best_theta)
这是错误:
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
<ipython-input-41-f199ccce6734> in <module>
27 for epoch in range(n_epochs):
28 for batch_index in range(n_batches):
---> 29 X_batch,Y_batch=fetch_batch(epoch,batch_index,batch_size)
30 sess.run(training_op,feed_dict={X:X_batch,Y:Y_batch})
31
<ipython-input-41-f199ccce6734> in fetch_batch(epoch, batch_index, batch_size)
19 def fetch_batch(epoch,batch_index,batch_size):
20 [...]
---> 21 return X_batch, Y_batch
22
23 init=tf.global_variables_initializer()
NameError: name 'X_batch' is not defined
所以我的问题是,我应该用那个 [...] 做什么,它是从磁盘加载数据的真正方法还是应该用其他东西替换它?
解决方案
推荐阅读
- sql - 数据类型 TIMESTAMP_LTZ 在雪花表中显示的时间值不正确
- python - Django通过订单模型获取产品总价
- ios - 有没有办法使用几何阅读器将我的 SwiftUI 图像的宽度定义为相对于屏幕尺寸?
- java - 使用 netbeans 11 配置 JavaFX
- c# - LINQ 如何在收集结束时删除空元素
- javascript - 如何在 vuejs 中构建一些包含 .vue 文件的文件夹?
- reactjs - 如何修复 npx create-react-app 错误?
- java - 如何在 Android 上定义圆形 ImageView?
- angular - 400 错误请求“将值 {null} 转换为类型‘System.Int32’时出错。路径‘id’,第 1 行,位置 10。” POST 请求
- reactjs - React WordPress Customizer 自定义控件