python - tf.variable 的扩充大小
问题描述
我autoencoder
通过提供 2 个存储以下内容的占位符来训练 an:
x1 = [x1]
X = [x1,x2,x3...xn]
它认为:
y1 = W*x1 + b_encoding1
因此,我有一个名为b_encoder1
(b) 的变量(当我打印它时,我得到<tf.Variable 'b_encoder1:0' shape=(10,) dtype=float32_ref>
:)
但它也认为:
Y = W*X + b_encoding1
第二个的大小b_encoding1
必须(10,n)
是(10,)
. 我怎样才能增强它并传递它tensorflow
?
Y = tf.compat.v1.nn.xw_plus_b(X, W1, b_encoder1, name='Y')
整个代码如下所示:
x1 = tf.compat.v1.placeholder( tf.float32, [None,input_shape], name = 'x1')
X = tf.compat.v1.placeholder( tf.float32, [None,input_shape,sp], name = 'X')
W1 = tf.Variable(tf.initializers.GlorotUniform()(shape=[input_shape,code_length]),name='W1')
b_encoder1 = tf.compat.v1.get_variable(name='b_encoder1',shape=[code_length],initializer=tf.compat.v1.initializers.zeros(), use_resource=False)
K = tf.Variable(tf.initializers.GlorotUniform()(shape=[code_length,code_length]),name='K')
b_decoder1 = tf.compat.v1.get_variable(name='b_decoder1',shape=[input_shape],initializer=tf.compat.v1.initializers.zeros(), use_resource=False)
y1 = tf.compat.v1.nn.xw_plus_b(x1, W1, b_encoder1, name='y1')
Y = tf.compat.v1.nn.xw_plus_b(X, W1, b_encoder1, name='Y')
我还声明了损失函数等,然后训练:
with tf.compat.v1.Session() as sess:
sess.run(tf.compat.v1.global_variables_initializer())
for epoch_i in range(epochs):
for batch_i in range(number_of_batches):
batch_data = getBatch(shuffled_data, batch_i, batch_size)
sess.run(optimizer, feed_dict={x1: batch_data[:,:,0], X: batch_data})
train_loss = sess.run(loss, feed_dict={x1: aug_data[:,:,0], X: aug_data})
print(epoch_i, train_loss)
解决方案
您可以将X
其视为一批x
. X
可以接受任意数量的样本:
import tensorflow as tf
import numpy as np
X = tf.placeholder(shape=(None, 100), dtype=tf.float32)
W = tf.get_variable('kernel', [100,10])
b = tf.get_variable('bias',[10])
Y = tf.nn.xw_plus_b(X, W,b, name='Y')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) # tf version < 1.13
out = sess.run(Y, {X: np.random.rand(128, 100)}) # here n=128
请注意,b
无论 n 的值如何,偏差的维度仍然是 10-D。
推荐阅读
- r - 现在到 ggplot N-th 变量?
- javascript - 如何在分针中编辑时钟代码以 30 的间隔变化
- c - 什么是(uint32_t*)?
- sql - BigQuery 过滤列是否具有特定值,然后透视结果
- arrays - 如何使用 id 匹配的 json 对 Mongo DB 对象数组进行批量更新
- javascript - 使用反应路由器,如何使用路由配置传递组件道具?
- pandas - KeyError: 0 将 bs4 xml 转换为 pandas df
- python - 将数组添加到熊猫数据框
- python - Python cv2 不写视频
- algorithm - 保证图的边服从三角不等式