python - 张量切片会丢失 TensorFlow 中的形状信息
问题描述
我正在尝试动态切片张量以自动调整其形状以进行下一次迭代。然而,我意识到,当在图形模式下切片时,张量的形状信息会丢失,因此我无法进一步对其应用某些操作,这需要知道给定张量的形状。下面我附上了一个示例代码,在我的具体示例中,该opt_with_slicing
函数vectorized_map
在一个更大的函数中定义,该函数负责自动区分。由于原始函数太大而无法包含在此处,因此我对其进行了相应的简化;
a = tf.constant(np.linspace(0.,10.,11,endpoint=True)[::-1])
b = tf.ones((2,10))
def opt_with_slicing(x, some_cutoff: float):
a, b = x
new_size = tf.math.count_nonzero(
tf.cast(a >= some_cutoff, dtype=tf.int32), dtype=tf.int32
)
tf.print(f"new size {new_size}, initial size {a.get_shape()}")
test1 = b[:, :new_size]
test2 = tf.slice(b, [0, 0], [b.get_shape()[0], new_size])
tf.print(f"test1 shape {test1.get_shape()}, test2 shape {test2.get_shape()}")
return test1, test2
tf.function(opt_with_slicing)([a,b], 5.)
# Output:
# new size Tensor("count_nonzero/Cast_1:0", shape=(), dtype=int32), initial size (11,)
# test1 shape (2, None), test2 shape (2, None)
# (<tf.Tensor: shape=(2, 6), dtype=float32, numpy=
# array([[1., 1., 1., 1., 1., 1.],
# [1., 1., 1., 1., 1., 1.]], dtype=float32)>,
# <tf.Tensor: shape=(2, 6), dtype=float32, numpy=
# array([[1., 1., 1., 1., 1., 1.],
# [1., 1., 1., 1., 1., 1.]], dtype=float32)>)
正如您从打印出来的 和 的形状信息中看到的那样test1
,test2
由于这是一个动态操作,我无法知道new_size
执行之前的形状信息。有没有办法在不破坏图形模式的情况下恢复函数的形状信息?
PS:我也试过了boolean_mask
;
mask = tf.greater_equal(a, some_cutoff)
masked_shape = tf.boolean_mask(a, mask).get_shape()[0]
但masked_shape
事实证明None
也是如此。
系统信息:
- TensorFlow v2.5.0
- Python v3.8.2
解决方案
推荐阅读
- c# - ListView 使用向上和向下按钮移动项目 .NET Framework
- android - AR图像识别从原生开发到跨平台
- docker - docker container ls --size 不准确,如何获得准确的容器大小?
- angular - 带有视图模型前缀 (vm) 的 Angular 8 UI 模型绑定
- tfs - 将文件夹检入 TFS 服务器后从 TFS 中删除
- python - 从 python 控制台清除多行
- python - 让 Python 程序运行另一个 Python 程序并让它们同时运行?
- javascript - Electron 桌面应用程序在它应该是离线应用程序时尝试 TCP 连接
- java - 如何断言对象列表具有一组具有特定值的属性
- c# - 使用 EPPlus 将 Web 表单 DataGrid 导出到 excel