首页 > 解决方案 > tf.variable 的扩充大小

问题描述

autoencoder通过提供 2 个存储以下内容的占位符来训练 an:

x1 = [x1]

X = [x1,x2,x3...xn]

它认为:

y1 = W*x1 + b_encoding1

因此,我有一个名为b_encoder1(b) 的变量(当我打印它时,我得到<tf.Variable 'b_encoder1:0' shape=(10,) dtype=float32_ref>:)

但它也认为:

Y = W*X + b_encoding1

第二个的大小b_encoding1必须(10,n)(10,). 我怎样才能增强它并传递它tensorflow

Y = tf.compat.v1.nn.xw_plus_b(X, W1, b_encoder1, name='Y')

整个代码如下所示:

x1 = tf.compat.v1.placeholder( tf.float32, [None,input_shape], name = 'x1')
X = tf.compat.v1.placeholder( tf.float32, [None,input_shape,sp], name = 'X')

W1 = tf.Variable(tf.initializers.GlorotUniform()(shape=[input_shape,code_length]),name='W1')
b_encoder1 = tf.compat.v1.get_variable(name='b_encoder1',shape=[code_length],initializer=tf.compat.v1.initializers.zeros(), use_resource=False)
K = tf.Variable(tf.initializers.GlorotUniform()(shape=[code_length,code_length]),name='K')
b_decoder1 = tf.compat.v1.get_variable(name='b_decoder1',shape=[input_shape],initializer=tf.compat.v1.initializers.zeros(), use_resource=False)

y1 = tf.compat.v1.nn.xw_plus_b(x1, W1, b_encoder1, name='y1')
Y = tf.compat.v1.nn.xw_plus_b(X, W1, b_encoder1, name='Y')

我还声明了损失函数等,然后训练:

with tf.compat.v1.Session() as sess:

    sess.run(tf.compat.v1.global_variables_initializer())

    for epoch_i in range(epochs):

        for batch_i in range(number_of_batches):

            batch_data = getBatch(shuffled_data, batch_i, batch_size)
            sess.run(optimizer, feed_dict={x1: batch_data[:,:,0], X: batch_data})

        train_loss = sess.run(loss, feed_dict={x1: aug_data[:,:,0], X: aug_data})
        print(epoch_i, train_loss)

标签: pythontensorflow

解决方案


您可以将X其视为一批x. X可以接受任意数量的样本:

import tensorflow as tf
import numpy as np

X = tf.placeholder(shape=(None, 100), dtype=tf.float32)
W = tf.get_variable('kernel', [100,10])
b = tf.get_variable('bias',[10])
Y = tf.nn.xw_plus_b(X, W,b, name='Y')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # tf version < 1.13
    out = sess.run(Y, {X: np.random.rand(128, 100)})  # here n=128

请注意,b无论 n 的值如何,偏差的维度仍然是 10-D。


推荐阅读