python - 垫后的张量流切片不起作用,而没有垫它可以
问题描述
使用张量流 1.8
我有一个 RNN,我试图在其中填充然后将特征切成可变长度的标准化句子
# param.batch_size = 32 params.max_doc_len = 10
# features is of shape [32, 800+]
features = tf.sparse_tensor_to_dense(features, default_value=0)
features = tf.Print(features, [tf.shape(features), features], "Features after sparse2dense")
features = tf.pad(features, tf.constant([[0, 0], [0, params.max_doc_len]]), "CONSTANT")
features = tf.Print(features, [tf.shape(features), features], "Features after pad")
# same output with
# features = features[:, :params.max_doc_len]
features = tf.strided_slice(features, [0,0], [params.batch_size, params.max_doc_len], [1,1])
features = tf.Print(features, [tf.shape(features), features], "Features after pad and drop")
但是在切片时我得到了错误的尺寸:
Features after sparse2dense[32 858][[1038 5 104]...]
Features after pad[32 868][[1038 5 104]...]
Features after pad and drop[10 10][[1038 5 104]...]
如果我删除垫操作,我会得到正确的输出,如下所示:
Features after sparse2dense[32 858][[1038 5 104]...]
Features after pad and drop[32 10][[1038 5 104]...]
最糟糕的是,相同的代码在笔记本中运行良好(版本匹配)
t = tf.constant([[1, 2, 3], [4,3,2],[1, 2, 3], [4,3,2],[1, 2, 3],[9, 9, 9]])
MAX_DOC_LEN = 5
paddings = tf.constant([[0, 0], [0, MAX_DOC_LEN]])
padded = tf.pad(t, paddings, "CONSTANT")
cropped = padded[:, :MAX_DOC_LEN]
with tf.Session() as sess:
print(tf.shape(t).eval()) # [6 3]
print(tf.shape(padded).eval()) # [6 8]
print(tf.shape(cropped).eval()) # [6 5]
现在的问题是我做错了什么?
解决方案
如果我理解正确,您正在尝试用零填充每一行,以便长度固定。原来有一个非常简单的解决方案,就在你的代码的第一行(注意我已经替换tf.sparse_tensor_to_dense()
为tf.sparse_to_dense()
- 这些是不同的!):
filter = tf.less( features.indices[ :, 1 ], params.max_doc_len )
features = tf.sparse_retain( features, filter )
features = tf.sparse_to_dense( sparse_indices = features.indices,
output_shape = ( params.batch_size, params.max_doc_len ),
sparse_values = features.values,
default_value = 0 )
前两行只是实现了一个过滤器来丢弃任何超出 的值max_doc_len
,因此基本上截断了所有行。
这里的主要思想是tf.sparse_to_dense()
允许手动指定我们想要的结果张量的形状,并且无论如何它都会用零填充其余部分。因此,这一行可以完成您的代码部分的用途。
PS 尽管如此,TensorFlow 中可能存在的错误仍然存在,但我无法在任何地方重现该问题。
推荐阅读
- django - 查询集在视图中获取并传递给模板
- php - magento 2:map.xml 没有忽略字段
- php - 在 laravel 中覆盖/更改供应商/主题文件的最佳实践
- c - 哪些优化技术应用于总结简单算术序列的 Rust 代码?
- ios - 为什么我从 AppStore 下载的 Xcode10.0 存档的应用程序在 iOS9.0 中崩溃?
- react-native - 如何在本机反应中记忆流式传输?
- git - 我想在 webhook 触发后从我的 rsync jenkins 作业中排除一些文件
- go - GORM JOIN 和结果
- ios - 如何在swift 3中解决标签中的重叠最后一个值
- mysql - 在MYSQL中合并两列