python - 在张量流中输入一个带占位符的整数?
问题描述
我想batch_size
在 Tensorflow 中输入一个整数作为占位符。但它不充当整数。考虑以下示例:
import tensorflow as tf
max_length = 5
batch_size = 3
batch_size_placeholder = tf.placeholder(dtype=tf.int32)
mask_0 = tf.one_hot(indices=[0]*batch_size_placeholder, depth=max_length, on_value=0., off_value=1.)
mask_1 = tf.one_hot(indices=[0]*batch_size, depth=max_length, on_value=0., off_value=1.)
# new session
with tf.Session() as sess:
feed = {batch_size_placeholder : 3}
batch, mask0, mask1 = sess.run([
batch_size_placeholder, mask_0, mask_1
], feed_dict=feed)
当我打印 的值时,batch
我有以下内容:mask0
mask1
print(batch)
>>> array(3, dtype=int32)
print(mask0)
>>> array([[0., 1., 1., 1., 1.]], dtype=float32)
print(mask1)
>>> array([[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.]], dtype=float32)
确实我认为mask0
并且mask1
必须是相同的,但似乎 Tensorflow 并没有将batch_size_placeholder
其视为整数。我相信这将是一个张量,但无论如何我可以在我的计算中将它用作整数吗?
无论如何我可以解决这个问题吗?仅供参考,我仅用作示例,我想在训练期间在我的代码中运行训练/验证,在训练和验证步骤中tf.one_hot
,我需要许多具有不同值的其他计算。batch_size
任何帮助,将不胜感激。
解决方案
在纯 python 用法中,[0]*3
将是[0,0,0]
. 但是,batch_size_placeholder
是一个占位符,在图形执行期间,它将是一个张量。[0]*tensor
将被视为张量乘法。在您的情况下,它将是一个值为 0 的一维张量。要正确使用batch_size_placeholder
,您应该创建一个与 具有相同长度的张量batch_size_placeholder
。
mask_0 = tf.one_hot(tf.zeros(batch_size_placeholder, dtype=tf.int32), depth=max_length, on_value=0., off_value=1.)
它将具有与 相同的结果mask_1
。
一个简单的例子来显示差异。
batch_size_placeholder = tf.placeholder(dtype=tf.int32)
a = [0]*batch_size_placeholder
b = tf.zeros(batch_size_placeholder, dtype=tf.int32)
with tf.Session() as sess:
print(sess.run([a, b], feed_dict={batch_size_placeholder : 3}))
# [array([0], dtype=int32), array([0, 0, 0], dtype=int32)]
推荐阅读
- c# - 使用 LINQ 批量更新
- python - 如何在 csv 文件中搜索字符串?
- javascript - 使用 qunit 和 sinon 测试 jquery click
- c - GetProcessTimes ExitTime 来自未提升的进程
- python - 无法将大小为 12212 的数组重塑为形状 (400,400,180)
- oracle - 即使我在 OracleConfiguration 中定义了路径,Oracle 也会给出错误 12154
- asp.net - 如何以编程方式使 GridView 中的字段与 AutoGenerateEditButton 一起可编辑?
- parsing - 朱莉娅:可以(应该)在“解析时间”捕获这种类型的错误吗?
- javascript - RxJS:为什么内部可观察首先触发?
- r - 限制data.frame中的列超过条件